From 886ac70a6cb92d2615e714e9ca2bfcec973065ca Mon Sep 17 00:00:00 2001 From: Durgadoss R Date: Thu, 29 Aug 2024 18:09:14 +0530 Subject: [PATCH] [APFloat] Add APFloat support for E8M0 type This patch adds an APFloat type for unsigned E8M0 format. This format is used for representing the "scale-format" in the MX specification: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This format does not support {Inf, denorms, zeroes}. Like FP32, this format's exponents are 8-bits (all bits here) and the bias value is 127. However, it differs from IEEE-FP32 in that the minExponent is -127 (instead of -126). There are updates done in the APFloat utility functions to handle these constraints for this format. * The bias calculation is different and convertIEEE* APIs are updated to handle this. * Since there are no significand bits, the isSignificandAll{Zeroes/Ones} methods are updated accordingly. * Although the format does not have any precision, the precision bit in the fltSemantics is set to 1 for consistency with APFloat's internal representation. * Many utility functions are updated to handle the fact that this format does not support Zero. * Provide a separate initFromAPInt() implementation to handle the quirks of the format. * Add specific tests to verify the range of values for this format. Signed-off-by: Durgadoss R --- llvm/include/llvm/ADT/APFloat.h | 21 ++ llvm/lib/Support/APFloat.cpp | 158 +++++++++++--- llvm/unittests/ADT/APFloatTest.cpp | 324 +++++++++++++++++++++++++++++ 3 files changed, 478 insertions(+), 25 deletions(-) diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h index 9cc8369a0bf52b..469beb36ced5d9 100644 --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -195,6 +195,13 @@ struct APFloatBase { // improved range compared to half (16-bit) formats, at (potentially) // greater throughput than single precision (32-bit) formats. S_FloatTF32, + // 8-bit floating point number with (all the) 8 bits for the exponent + // like in FP32. There are no zeroes, no infinities, and no denormal values. + // This format has unsigned representation only. (U -> Unsigned only). + // NaN is represented with all bits set to 1. Bias is 127. + // This format represents the scale data type in the MX specification from: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + S_Float8E8M0FNU, // 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 // types, there are no infinity or NaN values. The format is detailed in // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf @@ -229,6 +236,7 @@ struct APFloatBase { static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE; static const fltSemantics &Float8E3M4() LLVM_READNONE; static const fltSemantics &FloatTF32() LLVM_READNONE; + static const fltSemantics &Float8E8M0FNU() LLVM_READNONE; static const fltSemantics &Float6E3M2FN() LLVM_READNONE; static const fltSemantics &Float6E2M3FN() LLVM_READNONE; static const fltSemantics &Float4E2M1FN() LLVM_READNONE; @@ -591,6 +599,7 @@ class IEEEFloat final : public APFloatBase { unsigned int significandLSB() const; unsigned int significandMSB() const; void zeroSignificand(); + unsigned int getNumHighBits() const; /// Return true if the significand excluding the integral bit is all ones. bool isSignificandAllOnes() const; bool isSignificandAllOnesExceptLSB() const; @@ -652,6 +661,7 @@ class IEEEFloat final : public APFloatBase { APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const; APInt convertFloat8E3M4APFloatToAPInt() const; APInt convertFloatTF32APFloatToAPInt() const; + APInt convertFloat8E8M0FNUAPFloatToAPInt() const; APInt convertFloat6E3M2FNAPFloatToAPInt() const; APInt convertFloat6E2M3FNAPFloatToAPInt() const; APInt convertFloat4E2M1FNAPFloatToAPInt() const; @@ -672,6 +682,7 @@ class IEEEFloat final : public APFloatBase { void initFromFloat8E4M3B11FNUZAPInt(const APInt &api); void initFromFloat8E3M4APInt(const APInt &api); void initFromFloatTF32APInt(const APInt &api); + void initFromFloat8E8M0FNUAPInt(const APInt &api); void initFromFloat6E3M2FNAPInt(const APInt &api); void initFromFloat6E2M3FNAPInt(const APInt &api); void initFromFloat4E2M1FNAPInt(const APInt &api); @@ -1079,6 +1090,9 @@ class APFloat : public APFloatBase { /// \param Semantics - type float semantics static APFloat getAllOnesValue(const fltSemantics &Semantics); + /// Returns true if the given semantics supports either NaN or Infinity. + /// + /// \param Sem - type float semantics static bool hasNanOrInf(const fltSemantics &Sem) { switch (SemanticsToEnum(Sem)) { default: @@ -1091,6 +1105,13 @@ class APFloat : public APFloatBase { } } + /// Returns true if the given semantics has actual significand. + /// + /// \param Sem - type float semantics + static bool hasSignificand(const fltSemantics &Sem) { + return &Sem != &Float8E8M0FNU(); + } + /// Used to insert APFloat objects, or objects that contain APFloat objects, /// into FoldingSets. void Profile(FoldingSetNodeID &NID) const; diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index dee917fd56104c..517f234f4be721 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -119,6 +119,13 @@ struct fltSemantics { fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754; fltNanEncoding nanEncoding = fltNanEncoding::IEEE; + + /* Whether this semantics has an encoding for Zero */ + bool hasZero = true; + + /* Whether this semantics can represent signed values */ + bool hasSignedRepr = true; + // Returns true if any number described by this semantics can be precisely // represented by the specified semantics. Does not take into account // the value of fltNonfiniteBehavior. @@ -145,6 +152,10 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = { 4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8}; static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19}; +static constexpr fltSemantics semFloat8E8M0FNU = { + 127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes, + false, false}; + static constexpr fltSemantics semFloat6E3M2FN = { 4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly}; static constexpr fltSemantics semFloat6E2M3FN = { @@ -222,6 +233,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { return Float8E3M4(); case S_FloatTF32: return FloatTF32(); + case S_Float8E8M0FNU: + return Float8E8M0FNU(); case S_Float6E3M2FN: return Float6E3M2FN(); case S_Float6E2M3FN: @@ -264,6 +277,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { return S_Float8E3M4; else if (&Sem == &llvm::APFloat::FloatTF32()) return S_FloatTF32; + else if (&Sem == &llvm::APFloat::Float8E8M0FNU()) + return S_Float8E8M0FNU; else if (&Sem == &llvm::APFloat::Float6E3M2FN()) return S_Float6E3M2FN; else if (&Sem == &llvm::APFloat::Float6E2M3FN()) @@ -294,6 +309,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { } const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; } const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; } +const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; } const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; } const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; } const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; } @@ -396,7 +412,8 @@ static inline Error createError(const Twine &Err) { } static constexpr inline unsigned int partCountForBits(unsigned int bits) { - return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth; + return std::max(1u, (bits + APFloatBase::integerPartWidth - 1) / + APFloatBase::integerPartWidth); } /* Returns 0U-9U. Return values >= 10U are not digits. */ @@ -918,6 +935,10 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly) llvm_unreachable("This floating point format does not support NaN"); + if (Negative && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); + category = fcNaN; sign = Negative; exponent = exponentNaN(); @@ -955,7 +976,8 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { significand[part] = 0; } - unsigned QNaNBit = semantics->precision - 2; + unsigned QNaNBit = + (semantics->precision >= 2) ? (semantics->precision - 2) : 0; if (SNaN) { // We always have to clear the QNaN bit to make it an SNaN. @@ -1025,6 +1047,19 @@ bool IEEEFloat::isSmallestNormalized() const { isSignificandAllZerosExceptMSB(); } +unsigned int IEEEFloat::getNumHighBits() const { + const unsigned int PartCount = partCountForBits(semantics->precision); + const unsigned int Bits = PartCount * integerPartWidth; + + // Compute how many bits are used in the final word. + // When precision is just 1, it represents the 'Pth' + // Precision bit and not the actual significand bit. + const unsigned int NumHighBits = (semantics->precision > 1) + ? (Bits - semantics->precision + 1) + : (Bits - semantics->precision); + return NumHighBits; +} + bool IEEEFloat::isSignificandAllOnes() const { // Test if the significand excluding the integral bit is all ones. This allows // us to test for binade boundaries. @@ -1035,13 +1070,12 @@ bool IEEEFloat::isSignificandAllOnes() const { return false; // Set the unused high bits to all ones when we compare. - const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) << (integerPartWidth - NumHighBits); - if (~(Parts[PartCount - 1] | HighBitFill)) + if ((semantics->precision <= 1) || (~(Parts[PartCount - 1] | HighBitFill))) return false; return true; @@ -1062,8 +1096,7 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const { } // Set the unused high bits to all ones when we compare. - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) @@ -1085,13 +1118,12 @@ bool IEEEFloat::isSignificandAllZeros() const { return false; // Compute how many bits are used in the final word. - const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits < integerPartWidth && "Can not have more high bits to " "clear than integerPartWidth"); const integerPart HighBitMask = ~integerPart(0) >> NumHighBits; - if (Parts[PartCount - 1] & HighBitMask) + if ((semantics->precision > 1) && (Parts[PartCount - 1] & HighBitMask)) return false; return true; @@ -1106,25 +1138,26 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const { return false; } - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; - return Parts[PartCount - 1] == integerPart(1) - << (integerPartWidth - NumHighBits); + const unsigned NumHighBits = getNumHighBits(); + const integerPart MSBMask = integerPart(1) + << (integerPartWidth - NumHighBits); + return ((semantics->precision <= 1) || (Parts[PartCount - 1] == MSBMask)); } bool IEEEFloat::isLargest() const { + bool IsMaxExp = isFiniteNonZero() && exponent == semantics->maxExponent; if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && semantics->nanEncoding == fltNanEncoding::AllOnes) { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones except // the LSB. - return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnesExceptLSB(); + return (IsMaxExp && APFloat::hasSignificand(*semantics)) + ? isSignificandAllOnesExceptLSB() + : IsMaxExp; } else { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones. - return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnes(); + return IsMaxExp && isSignificandAllOnes(); } } @@ -1165,7 +1198,13 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) { IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) { initialize(&ourSemantics); - makeZero(false); + // The Float8E8MOFNU format does not have a representation + // for zero. So, use the closest representation instead. + // Moreover, the all-zero encoding represents a valid + // normal value (which is the smallestNormalized here). + // Hence, we call makeSmallestNormalized (where category is + // 'fcNormal') instead of makeZero (where category is 'fcZero'). + ourSemantics.hasZero ? makeZero(false) : makeSmallestNormalized(false); } // Delegate to the previous constructor, because later copy constructor may @@ -1729,6 +1768,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, category = fcZero; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + // This condition handles the case where the semantics + // does not have zero but uses the all-zero encoding + // to represent the smallest normal value. + if (!semantics->hasZero) + makeSmallestNormalized(false); } /* The fcZero case is a denormal that underflowed to zero. */ @@ -2606,6 +2650,8 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, fs = opOK; } + if (category == fcZero && !semantics->hasZero) + makeSmallestNormalized(false); return fs; } @@ -3070,6 +3116,8 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) { fs = opOK; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!semantics->hasZero) + makeSmallestNormalized(false); /* Check whether the normalized exponent is high enough to overflow max during the log-rebasing in the max-exponent check below. */ @@ -3237,6 +3285,10 @@ IEEEFloat::convertFromString(StringRef str, roundingMode rounding_mode) { StringRef::iterator p = str.begin(); size_t slen = str.size(); sign = *p == '-' ? 1 : 0; + if (sign && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); + if (*p == '-' || *p == '+') { p++; slen--; @@ -3533,15 +3585,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const { template APInt IEEEFloat::convertIEEEFloatToAPInt() const { assert(semantics == &S); - - constexpr int bias = -(S.minExponent - 1); + const int bias = + (semantics == &semFloat8E8M0FNU) ? -S.minExponent : -(S.minExponent - 1); constexpr unsigned int trailing_significand_bits = S.precision - 1; constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth; constexpr integerPart integer_bit = integerPart{1} << (trailing_significand_bits % integerPartWidth); constexpr uint64_t significand_mask = integer_bit - 1; constexpr unsigned int exponent_bits = - S.sizeInBits - 1 - trailing_significand_bits; + trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits) + : S.sizeInBits; static_assert(exponent_bits < 64); constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1; @@ -3557,6 +3610,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const { !(significandParts()[integer_bit_part] & integer_bit)) myexponent = 0; // denormal } else if (category == fcZero) { + if (!S.hasZero) + llvm_unreachable("semantics does not support zero!"); myexponent = ::exponentZero(S) + bias; mysignificand.fill(0); } else if (category == fcInfinity) { @@ -3659,6 +3714,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const { return convertIEEEFloatToAPInt(); } +APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const { + assert(partCount() == 1); + return convertIEEEFloatToAPInt(); +} + APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const { assert(partCount() == 1); return convertIEEEFloatToAPInt(); @@ -3721,6 +3781,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloatTF32) return convertFloatTF32APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU) + return convertFloat8E8M0FNUAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN) return convertFloat6E3M2FNAPFloatToAPInt(); @@ -3819,6 +3882,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) { } } +// The E8M0 format has the following characteristics: +// It is an 8-bit unsigned format with only exponents (no actual significand). +// No encodings for {zero, infinities or denorms}. +// NaN is represented by all 1's. +// Bias is 127. +void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) { + const uint64_t exponent_mask = 0xff; + uint64_t val = api.getRawData()[0]; + uint64_t myexponent = (val & exponent_mask); + + initialize(&semFloat8E8M0FNU); + assert(partCount() == 1); + + // This format has unsigned representation only + sign = 0; + + // Set the significand + // This format does not have any significand but the 'Pth' precision bit is + // always set to 1 for consistency in APFloat's internal representation. + uint64_t mysignificand = 1; + significandParts()[0] = mysignificand; + + // This format can either have a NaN or fcNormal + // All 1's i.e. 255 is a NaN + if (val == exponent_mask) { + category = fcNaN; + exponent = exponentNaN(); + return; + } + // Handle fcNormal... + category = fcNormal; + exponent = myexponent - 127; // 127 is bias + return; +} template void IEEEFloat::initFromIEEEAPInt(const APInt &api) { assert(api.getBitWidth() == S.sizeInBits); @@ -3999,6 +4096,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromFloat8E3M4APInt(api); if (Sem == &semFloatTF32) return initFromFloatTF32APInt(api); + if (Sem == &semFloat8E8M0FNU) + return initFromFloat8E8M0FNUAPInt(api); if (Sem == &semFloat6E3M2FN) return initFromFloat6E3M2FNAPInt(api); if (Sem == &semFloat6E2M3FN) @@ -4032,9 +4131,9 @@ void IEEEFloat::makeLargest(bool Negative) { significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth) ? (~integerPart(0) >> NumUnusedHighBits) : 0; - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && - semantics->nanEncoding == fltNanEncoding::AllOnes) + semantics->nanEncoding == fltNanEncoding::AllOnes && + (semantics->precision > 1)) significand[0] &= ~integerPart(1); } @@ -4509,6 +4608,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { exponent = 0; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!semantics->hasZero) + makeSmallestNormalized(false); break; } @@ -4574,7 +4675,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { // the integral bit to 1, and increment the exponent. If we have a // denormal always increment since moving denormals and the numbers in the // smallest normal binade have the same exponent in our representation. - bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes(); + // If there are only exponents, any increment always crosses the + // BinadeBoundary. + bool WillCrossBinadeBoundary = true; + if (APFloat::hasSignificand(*semantics)) + WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes(); if (WillCrossBinadeBoundary) { integerPart *Parts = significandParts(); @@ -4626,6 +4731,9 @@ void IEEEFloat::makeInf(bool Negative) { } void IEEEFloat::makeZero(bool Negative) { + if (!semantics->hasZero) + llvm_unreachable("This floating point format does not support Zero"); + category = fcZero; sign = Negative; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) { diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp index 6c49d78e5c8ea9..605e14ad657979 100644 --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -814,6 +814,11 @@ TEST(APFloatTest, IsSmallestNormalized) { const fltSemantics &Semantics = APFloat::EnumToSemantics(static_cast(I)); + // For Float8E8M0FNU format, the below cases are tested + // through Float8E8M0FNUSmallest and Float8E8M0FNUNext tests. + if (I == APFloat::S_Float8E8M0FNU) + continue; + EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized()); EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized()); @@ -1907,6 +1912,57 @@ TEST(DoubleAPFloatTest, isInteger) { EXPECT_FALSE(T3.isInteger()); } +// Test to check if the full range of Float8E8M0FNU +// values are being represented correctly. +TEST(APFloatTest, Float8E8M0FNUValues) { + // High end of the range + auto test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p127"); + EXPECT_EQ(0x1.0p127, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p126"); + EXPECT_EQ(0x1.0p126, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p125"); + EXPECT_EQ(0x1.0p125, test.convertToDouble()); + + // tests the fix in makeLargest() + test = APFloat::getLargest(APFloat::Float8E8M0FNU()); + EXPECT_EQ(0x1.0p127, test.convertToDouble()); + + // tests overflow to nan + APFloat nan = APFloat(APFloat::Float8E8M0FNU(), "nan"); + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p128"); + EXPECT_TRUE(test.bitwiseIsEqual(nan)); + + // Mid of the range + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p0"); + EXPECT_EQ(1.0, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p1"); + EXPECT_EQ(2.0, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p2"); + EXPECT_EQ(4.0, test.convertToDouble()); + + // Low end of the range + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-125"); + EXPECT_EQ(0x1.0p-125, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-126"); + EXPECT_EQ(0x1.0p-126, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127"); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + + // Smallest value + test = APFloat::getSmallest(APFloat::Float8E8M0FNU()); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + + // Value below the smallest, but clamped to the smallest + test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-128"); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); +} + TEST(APFloatTest, getLargest) { EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat()); EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble()); @@ -1919,6 +1975,8 @@ TEST(APFloatTest, getLargest) { 30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble()); EXPECT_EQ(3.40116213421e+38f, APFloat::getLargest(APFloat::FloatTF32()).convertToFloat()); + EXPECT_EQ(1.701411834e+38f, + APFloat::getLargest(APFloat::Float8E8M0FNU()).convertToDouble()); EXPECT_EQ(28, APFloat::getLargest(APFloat::Float6E3M2FN()).convertToDouble()); EXPECT_EQ(7.5, APFloat::getLargest(APFloat::Float6E2M3FN()).convertToDouble()); @@ -2002,6 +2060,13 @@ TEST(APFloatTest, getSmallest) { EXPECT_TRUE(test.isFiniteNonZero()); EXPECT_TRUE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + test = APFloat::getSmallest(APFloat::Float8E8M0FNU()); + expected = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_FALSE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); } TEST(APFloatTest, getSmallestNormalized) { @@ -2108,6 +2173,14 @@ TEST(APFloatTest, getSmallestNormalized) { EXPECT_FALSE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); EXPECT_TRUE(test.isSmallestNormalized()); + + test = APFloat::getSmallestNormalized(APFloat::Float8E8M0FNU(), false); + expected = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_FALSE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); + EXPECT_TRUE(test.isSmallestNormalized()); } TEST(APFloatTest, getZero) { @@ -5791,6 +5864,46 @@ TEST(APFloatTest, Float8E4M3FNExhaustive) { } } +TEST(APFloatTest, Float8E8M0FNUExhaustive) { + // Test each of the 256 Float8E8M0FNU values. + for (int i = 0; i < 256; i++) { + APFloat test(APFloat::Float8E8M0FNU(), APInt(8, i)); + SCOPED_TRACE("i=" + std::to_string(i)); + + // isLargest + if (i == 254) { + EXPECT_TRUE(test.isLargest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p127); + } else { + EXPECT_FALSE(test.isLargest()); + } + + // isSmallest + if (i == 0) { + EXPECT_TRUE(test.isSmallest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p-127); + } else { + EXPECT_FALSE(test.isSmallest()); + } + + // convert to Double + bool losesInfo; + std::string val = std::to_string(i - 127); // 127 is the bias + llvm::SmallString<16> str("0x1.0p"); + str += val; + APFloat test2(APFloat::IEEEdouble(), str); + + APFloat::opStatus status = test.convert( + APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_FALSE(losesInfo); + if (i == 255) + EXPECT_TRUE(test.isNaN()); + else + EXPECT_EQ(test.convertToDouble(), test2.convertToDouble()); + } +} + TEST(APFloatTest, Float8E5M2FNUZNext) { APFloat test(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized); APFloat expected(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized); @@ -7067,6 +7180,12 @@ TEST(APFloatTest, getExactLog2) { auto SemEnum = static_cast(I); const fltSemantics &Semantics = APFloat::EnumToSemantics(SemEnum); + // For the Float8E8M0FNU format, the below cases along + // with some more corner cases are tested through + // Float8E8M0FNUGetExactLog2. + if (I == APFloat::S_Float8E8M0FNU) + continue; + APFloat One(Semantics, "1.0"); if (I == APFloat::S_PPCDoubleDouble) { @@ -7136,6 +7255,211 @@ TEST(APFloatTest, getExactLog2) { } } +TEST(APFloatTest, Float8E8M0FNUGetZero) { +#ifdef GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FNU(), false), + "This floating point format does not support Zero"); + EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FNU(), true), + "This floating point format does not support Zero"); +#endif +#endif +} + +TEST(APFloatTest, Float8E8M0FNUGetSignedValues) { +#ifdef GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + EXPECT_DEATH(APFloat(APFloat::Float8E8M0FNU(), "-64"), + "This floating point format does not support signed values"); + EXPECT_DEATH(APFloat(APFloat::Float8E8M0FNU(), "-0x1.0p128"), + "This floating point format does not support signed values"); + EXPECT_DEATH(APFloat::getNaN(APFloat::Float8E8M0FNU(), true), + "This floating point format does not support signed values"); +#endif +#endif +} + +TEST(APFloatTest, Float8E8M0FNUGetInf) { + // The E8M0 format does not support infinity and the + // all ones representation is treated as NaN. + APFloat t = APFloat::getInf(APFloat::Float8E8M0FNU()); + EXPECT_TRUE(t.isNaN()); + EXPECT_FALSE(t.isInfinity()); +} + +TEST(APFloatTest, Float8E8M0FNUFromString) { + // Exactly representable + EXPECT_EQ(64, APFloat(APFloat::Float8E8M0FNU(), "64").convertToDouble()); + // Overflow to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "0x1.0p128").isNaN()); + // Inf converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "inf").isNaN()); + // NaN converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "nan").isNaN()); +} + +TEST(APFloatTest, Float8E8M0FNUDivideByZero) { + APFloat x(APFloat::Float8E8M0FNU(), "1"); + APFloat zero(APFloat::Float8E8M0FNU(), "0"); + x.divide(zero, APFloat::rmNearestTiesToEven); + + // Zero is represented as the smallest normalized value + // in this format i.e 2^-127. + // This tests the fix in convertFromDecimalString() function. + EXPECT_EQ(0x1.0p-127, zero.convertToDouble()); + + // [1 / (2^-127)] = 2^127 + EXPECT_EQ(0x1.0p127, x.convertToDouble()); +} + +TEST(APFloatTest, Float8E8M0FNUGetExactLog2) { + const fltSemantics &Semantics = APFloat::Float8E8M0FNU(); + APFloat One(Semantics, "1.0"); + EXPECT_EQ(0, One.getExactLog2()); + + // In the Float8E8M0FNU format, 3 is rounded-up to 4. + // So, we expect 2 as the result. + EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2()); + EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2Abs()); + + // In the Float8E8M0FNU format, 5 is rounded-down to 4. + // So, we expect 2 as the result. + EXPECT_EQ(2, APFloat(Semantics, "5.0").getExactLog2()); + EXPECT_EQ(2, APFloat(Semantics, "5.0").getExactLog2Abs()); + + // Exact power-of-two value. + EXPECT_EQ(3, APFloat(Semantics, "8.0").getExactLog2()); + EXPECT_EQ(3, APFloat(Semantics, "8.0").getExactLog2Abs()); + + // Negative exponent value. + EXPECT_EQ(-2, APFloat(Semantics, "0.25").getExactLog2()); + EXPECT_EQ(-2, APFloat(Semantics, "0.25").getExactLog2Abs()); + + int MinExp = APFloat::semanticsMinExponent(Semantics); + int MaxExp = APFloat::semanticsMaxExponent(Semantics); + int Precision = APFloat::semanticsPrecision(Semantics); + + // Values below the minExp getting capped to minExp. + EXPECT_EQ(-127, + scalbn(One, MinExp - Precision - 1, APFloat::rmNearestTiesToEven) + .getExactLog2()); + EXPECT_EQ(-127, scalbn(One, MinExp - Precision, APFloat::rmNearestTiesToEven) + .getExactLog2()); + + // Values above the maxExp overflow to NaN, and getExactLog2() returns + // INT_MIN for these cases. + EXPECT_EQ( + INT_MIN, + scalbn(One, MaxExp + 1, APFloat::rmNearestTiesToEven).getExactLog2()); + + // This format can represent [minExp, maxExp]. + // So, the result is the same as the 'Exp' of the scalbn. + for (int i = MinExp - Precision + 1; i <= MaxExp; ++i) { + EXPECT_EQ(i, scalbn(One, i, APFloat::rmNearestTiesToEven).getExactLog2()); + } +} + +TEST(APFloatTest, Float8E8M0FNUSmallest) { + APFloat test(APFloat::getSmallest(APFloat::Float8E8M0FNU())); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + + // For E8M0 format, there are no denorms. + // So, getSmallest is equal to isSmallestNormalized(). + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_EQ(fcPosNormal, test.classify()); + + test = APFloat::getAllOnesValue(APFloat::Float8E8M0FNU()); + EXPECT_FALSE(test.isSmallestNormalized()); + EXPECT_TRUE(test.isNaN()); +} + +TEST(APFloatTest, Float8E8M0FNUNext) { + APFloat test(APFloat::getSmallest(APFloat::Float8E8M0FNU())); + // Increment of 1 should reach 2^-126 + EXPECT_EQ(APFloat::opOK, test.next(false)); + EXPECT_FALSE(test.isSmallestNormalized()); + EXPECT_EQ(0x1.0p-126, test.convertToDouble()); + + // Decrement of 1, again, should reach 2^-127 + // i.e. smallest normalized + EXPECT_EQ(APFloat::opOK, test.next(true)); + EXPECT_TRUE(test.isSmallestNormalized()); + + // Decrement again, but gets capped at the smallest normalized + EXPECT_EQ(APFloat::opOK, test.next(true)); + EXPECT_TRUE(test.isSmallestNormalized()); +} + +TEST(APFloatTest, ConvertDoubleToE8M0FN) { + bool losesInfo; + APFloat test(APFloat::IEEEdouble(), "1.0"); + APFloat::opStatus status = test.convert( + APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(1.0, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // For E8M0, zero encoding is represented as the smallest normalized value. + test = APFloat(APFloat::IEEEdouble(), "0.0"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test that the conversion of a power-of-two value is precise. + test = APFloat(APFloat::IEEEdouble(), "8.0"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(8.0f, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test to check round-down conversion to power-of-two. + // The fractional part of 9 is "001" (i.e. 1.125x2^3=9). + test = APFloat(APFloat::IEEEdouble(), "9.0"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(8.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Test to check round-up conversion to power-of-two. + // The fractional part of 13 is "101" (i.e. 1.625x2^3=13). + test = APFloat(APFloat::IEEEdouble(), "13.0"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(16.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Test to check round-up conversion to power-of-two. + // The fractional part of 12 is "100" (i.e. 1.5x2^3=12). + test = APFloat(APFloat::IEEEdouble(), "12.0"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(16.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Overflow to NaN. + test = APFloat(APFloat::IEEEdouble(), "0x1.0p128"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isNaN()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opOverflow | APFloat::opInexact); + + // Underflow to smallest normalized value. + test = APFloat(APFloat::IEEEdouble(), "0x1.0p-128"); + status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact); +} + TEST(APFloatTest, Float6E3M2FNFromString) { // Exactly representable EXPECT_EQ(28, APFloat(APFloat::Float6E3M2FN(), "28").convertToDouble());