[APFloat] Add APFloat support for E8M0 type #107127

Open
wants to merge 1 commit into
base: main

durga4github
Contributor

This patch adds an APFloat type for the unsigned E8M0 format, which represents the "scale format" in the MX specification (section 5.4):
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This format does not support {Inf, denorms, zeroes}. Like FP32, its exponent is 8 bits wide (here, all of the bits) and the bias is 127. However, it differs from IEEE FP32 in that minExponent is -127 (instead of -126). The APFloat utility functions are updated to handle these constraints:

  • The bias calculation is different, so the convertIEEE* APIs are updated to handle it.
  • Since there are no significand bits, the isSignificandAll{Zeroes/Ones} methods are updated accordingly.
  • Although the format has no significand bits, the precision field in fltSemantics is set to 1 for consistency with APFloat's internal representation.
  • Many utility functions are updated to handle the fact that this format cannot represent zero.
  • A separate initFromAPInt() implementation handles the quirks of the format.
  • Specific tests verify the range of values for this format.
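
As a rough illustration of the encoding described above, here is a standalone sketch that does not use APFloat. The helper names `biasFromMinExponent` and `decodeE8M0` are hypothetical and are not part of this patch. The sketch assumes only what the description states: 8 exponent bits, bias 127, all-ones meaning NaN, and no zero, infinity, or denormal encodings.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helpers for illustration only; none of these are APFloat APIs.
constexpr int kE8M0Bias = 127;
constexpr std::uint8_t kE8M0NaN = 0xFF;

// IEEE-style formats reserve biased exponent 0 for zero and denormals, so the
// smallest normal value has biased exponent 1 and bias = -(minExponent - 1).
// E8M0 reserves nothing at the low end, so bias = -minExponent.
constexpr int biasFromMinExponent(int minExponent, bool exponentOnly) {
  return exponentOnly ? -minExponent : -(minExponent - 1);
}
static_assert(biasFromMinExponent(-126, false) == 127, "FP32-style bias");
static_assert(biasFromMinExponent(-127, true) == 127, "E8M0 bias");

// Decode one E8M0 byte: 2^(bits - 127), or NaN when all bits are set.
inline double decodeE8M0(std::uint8_t bits) {
  if (bits == kE8M0NaN)
    return std::numeric_limits<double>::quiet_NaN();
  return std::ldexp(1.0, static_cast<int>(bits) - kE8M0Bias);
}
```

With these definitions, byte 0x00 decodes to 2^-127, 0x7F to 1.0, and 0xFE to 2^127, which matches the range covered by the Float8E8M0FNValues test in the patch.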

@llvmbot
Collaborator

llvmbot commented Sep 3, 2024

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-adt

Author: Durgadoss R (durga4github)

Changes

Patch is 28.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107127.diff

3 Files Affected:

  • (modified) llvm/include/llvm/ADT/APFloat.h (+29)
  • (modified) llvm/lib/Support/APFloat.cpp (+124-4)
  • (modified) llvm/unittests/ADT/APFloatTest.cpp (+221-18)
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index 7039e961bff82d..18b8b878c80161 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -195,6 +195,12 @@ struct APFloatBase {
     // improved range compared to half (16-bit) formats, at (potentially)
     // greater throughput than single precision (32-bit) formats.
     S_FloatTF32,
+    // 8-bit floating point number with (all the) 8 bits for the exponent
+    // like in FP32. There are no zeroes, no infinities, and no denormal values.
+    // NaN is represented with all bits set to 1. Bias is 127.
+    // This represents the scale data type in the MX specification from
+    // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    S_Float8E8M0FN,
     // 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754
     // types, there are no infinity or NaN values. The format is detailed in
     // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
@@ -229,6 +235,7 @@ struct APFloatBase {
   static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
   static const fltSemantics &Float8E3M4() LLVM_READNONE;
   static const fltSemantics &FloatTF32() LLVM_READNONE;
+  static const fltSemantics &Float8E8M0FN() LLVM_READNONE;
   static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
   static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
   static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
@@ -652,6 +659,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
   APInt convertFloat8E3M4APFloatToAPInt() const;
   APInt convertFloatTF32APFloatToAPInt() const;
+  APInt convertFloat8E8M0FNAPFloatToAPInt() const;
   APInt convertFloat6E3M2FNAPFloatToAPInt() const;
   APInt convertFloat6E2M3FNAPFloatToAPInt() const;
   APInt convertFloat4E2M1FNAPFloatToAPInt() const;
@@ -672,6 +680,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
   void initFromFloat8E3M4APInt(const APInt &api);
   void initFromFloatTF32APInt(const APInt &api);
+  void initFromFloat8E8M0FNAPInt(const APInt &api);
   void initFromFloat6E3M2FNAPInt(const APInt &api);
   void initFromFloat6E2M3FNAPInt(const APInt &api);
   void initFromFloat4E2M1FNAPInt(const APInt &api);
@@ -1091,6 +1100,26 @@ class APFloat : public APFloatBase {
     }
   }
 
+  static bool hasZero(const fltSemantics &Sem) {
+    switch (SemanticsToEnum(Sem)) {
+    default:
+      return true;
+    // The Float8E8M0FN does not have an encoding for Zeroes.
+    case APFloat::S_Float8E8M0FN:
+      return false;
+    }
+  }
+
+  static bool hasExponentOnly(const fltSemantics &Sem) {
+    switch (SemanticsToEnum(Sem)) {
+    default:
+      return false;
+    // The Float8E8M0FN has exponent only and no significand.
+    case APFloat::S_Float8E8M0FN:
+      return true;
+    }
+  }
+
   /// Used to insert APFloat objects, or objects that contain APFloat objects,
   /// into FoldingSets.
   void Profile(FoldingSetNodeID &NID) const;
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 7f68c5ab9b7cf7..4c71f4eca53640 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -145,6 +145,9 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
     4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
 static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
 static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
+static constexpr fltSemantics semFloat8E8M0FN = {
+      127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
+
 static constexpr fltSemantics semFloat6E3M2FN = {
     4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
 static constexpr fltSemantics semFloat6E2M3FN = {
@@ -222,6 +225,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
     return Float8E3M4();
   case S_FloatTF32:
     return FloatTF32();
+  case S_Float8E8M0FN:
+    return Float8E8M0FN();
   case S_Float6E3M2FN:
     return Float6E3M2FN();
   case S_Float6E2M3FN:
@@ -264,6 +269,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     return S_Float8E3M4;
   else if (&Sem == &llvm::APFloat::FloatTF32())
     return S_FloatTF32;
+  else if (&Sem == &llvm::APFloat::Float8E8M0FN())
+    return S_Float8E8M0FN;
   else if (&Sem == &llvm::APFloat::Float6E3M2FN())
     return S_Float6E3M2FN;
   else if (&Sem == &llvm::APFloat::Float6E2M3FN())
@@ -294,6 +301,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
 }
 const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
 const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
+const fltSemantics &APFloatBase::Float8E8M0FN() { return semFloat8E8M0FN; }
 const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
 const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
 const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
@@ -396,6 +404,8 @@ static inline Error createError(const Twine &Err) {
 }
 
 static constexpr inline unsigned int partCountForBits(unsigned int bits) {
+  if (bits == 0)
+    return 1;
   return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth;
 }
 
@@ -955,6 +965,12 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
       significand[part] = 0;
   }
 
+  // For the E8M0 types, precision is just 1 and the
+  // NaNBit handling below is not relevant.
+  // So, exit early.
+  if (semantics == &semFloat8E8M0FN)
+    return;
+
   unsigned QNaNBit = semantics->precision - 2;
 
   if (SNaN) {
@@ -1007,6 +1023,10 @@ IEEEFloat &IEEEFloat::operator=(IEEEFloat &&rhs) {
 }
 
 bool IEEEFloat::isDenormal() const {
+  // No denormals in Float8E8M0FN
+  if (semantics == &semFloat8E8M0FN)
+    return false;
+
   return isFiniteNonZero() && (exponent == semantics->minExponent) &&
          (APInt::tcExtractBit(significandParts(),
                               semantics->precision - 1) == 0);
@@ -1028,6 +1048,10 @@ bool IEEEFloat::isSmallestNormalized() const {
 bool IEEEFloat::isSignificandAllOnes() const {
   // Test if the significand excluding the integral bit is all ones. This allows
   // us to test for binade boundaries.
+  // For the E8M0 format, this is always false since there are no
+  // actual significand bits.
+  if (semantics == &semFloat8E8M0FN)
+    return false;
   const integerPart *Parts = significandParts();
   const unsigned PartCount = partCountForBits(semantics->precision);
   for (unsigned i = 0; i < PartCount - 1; i++)
@@ -1075,6 +1099,11 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
 }
 
 bool IEEEFloat::isSignificandAllZeros() const {
+  // For the E8M0 format, this is always true since there are no
+  // actual significand bits.
+  if (semantics == &semFloat8E8M0FN)
+    return true;
+
   // Test if the significand excluding the integral bit is all zeros. This
   // allows us to test for binade boundaries.
   const integerPart *Parts = significandParts();
@@ -1113,6 +1142,8 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const {
 }
 
 bool IEEEFloat::isLargest() const {
+  if (semantics == &semFloat8E8M0FN)
+    return isFiniteNonZero() && exponent == semantics->maxExponent;
   if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
       semantics->nanEncoding == fltNanEncoding::AllOnes) {
     // The largest number by magnitude in our format will be the floating point
@@ -1165,6 +1196,12 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {
 
 IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
   initialize(&ourSemantics);
+  // The E8M0 type cannot represent the value zero.
+  // So, initialize with the closest representation instead.
+  if (semantics == &semFloat8E8M0FN) {
+    makeSmallestNormalized(false);
+    return;
+  }
   makeZero(false);
 }
 
@@ -1727,6 +1764,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
   /* Canonicalize zeroes.  */
   if (omsb == 0) {
     category = fcZero;
+    // The E8M0 type cannot represent the value zero and
+    // thus the category cannot be fcZero. So, get the
+    // closest representation to fcZero instead.
+    if (semantics == &semFloat8E8M0FN)
+      makeSmallestNormalized(false);
     if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
       sign = false;
   }
@@ -2606,6 +2648,11 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics,
     fs = opOK;
   }
 
+  // The E8M0 type cannot represent the value zero and
+  // thus the category cannot be fcZero. So, get the
+  // closest representation to fcZero instead.
+  if (category == fcZero && semantics == &semFloat8E8M0FN)
+    makeSmallestNormalized(false);
   return fs;
 }
 
@@ -3070,6 +3117,11 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) {
     fs = opOK;
     if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
       sign = false;
+    // The E8M0 type cannot represent the value zero and
+    // thus the category cannot be fcZero. So, get the
+    // closest representation to fcZero instead.
+    if (semantics == &semFloat8E8M0FN)
+      makeSmallestNormalized(false);
 
     /* Check whether the normalized exponent is high enough to overflow
        max during the log-rebasing in the max-exponent check below. */
@@ -3533,15 +3585,15 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const {
 template <const fltSemantics &S>
 APInt IEEEFloat::convertIEEEFloatToAPInt() const {
   assert(semantics == &S);
-
-  constexpr int bias = -(S.minExponent - 1);
+  const int bias = (semantics == &semFloat8E8M0FN) ?
+    -S.minExponent : -(S.minExponent - 1);
   constexpr unsigned int trailing_significand_bits = S.precision - 1;
   constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
   constexpr integerPart integer_bit =
       integerPart{1} << (trailing_significand_bits % integerPartWidth);
   constexpr uint64_t significand_mask = integer_bit - 1;
-  constexpr unsigned int exponent_bits =
-      S.sizeInBits - 1 - trailing_significand_bits;
+  constexpr unsigned int exponent_bits = trailing_significand_bits ?
+    (S.sizeInBits - 1 - trailing_significand_bits) : S.sizeInBits;
   static_assert(exponent_bits < 64);
   constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;
 
@@ -3557,6 +3609,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
         !(significandParts()[integer_bit_part] & integer_bit))
       myexponent = 0; // denormal
   } else if (category == fcZero) {
+    if (semantics == &semFloat8E8M0FN)
+      llvm_unreachable("semantics does not support zero!");
     myexponent = ::exponentZero(S) + bias;
     mysignificand.fill(0);
   } else if (category == fcInfinity) {
@@ -3659,6 +3713,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
   return convertIEEEFloatToAPInt<semFloatTF32>();
 }
 
+APInt IEEEFloat::convertFloat8E8M0FNAPFloatToAPInt() const {
+  assert(partCount() == 1);
+  return convertIEEEFloatToAPInt<semFloat8E8M0FN>();
+}
+
 APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
   assert(partCount() == 1);
   return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
@@ -3721,6 +3780,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
     return convertFloatTF32APFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FN)
+    return convertFloat8E8M0FNAPFloatToAPInt();
+
   if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN)
     return convertFloat6E3M2FNAPFloatToAPInt();
 
@@ -3819,6 +3881,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) {
   }
 }
 
+// The E8M0 format has the following characteristics:
+// It is an 8-bit unsigned format with only exponents (no actual significand)
+// No encodings for {zero, infinities or denorms}
+// NaN is represented by all 1's
+// Bias is 127
+void IEEEFloat::initFromFloat8E8M0FNAPInt(const APInt &api) {
+  const uint64_t exponent_mask = 0xff;
+  uint64_t val = api.getRawData()[0];
+  uint64_t myexponent = (val & exponent_mask);
+
+  initialize(&semFloat8E8M0FN);
+  assert(partCount() == 1);
+
+  // This format has unsigned representation only
+  sign = 0;
+
+  // Set the significand
+  // This format does not have any significand but the 'Pth' precision bit is
+  // always set to 1 for consistency in APFloat's internal representation.
+  uint64_t mysignificand = 1;
+  significandParts()[0] = mysignificand;
+
+  // This format can either have a NaN or fcNormal
+  // All 1's i.e. 255 is a NaN
+  if (val == exponent_mask) {
+    category = fcNaN;
+    exponent = exponentNaN();
+    return;
+  }
+  // Handle fcNormal...
+  category = fcNormal;
+  exponent = myexponent - 127; // 127 is bias
+  return;
+}
 template <const fltSemantics &S>
 void IEEEFloat::initFromIEEEAPInt(const APInt &api) {
   assert(api.getBitWidth() == S.sizeInBits);
@@ -3999,6 +4095,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromFloat8E3M4APInt(api);
   if (Sem == &semFloatTF32)
     return initFromFloatTF32APInt(api);
+  if (Sem == &semFloat8E8M0FN)
+    return initFromFloat8E8M0FNAPInt(api);
   if (Sem == &semFloat6E3M2FN)
     return initFromFloat6E3M2FNAPInt(api);
   if (Sem == &semFloat6E2M3FN)
@@ -4032,6 +4130,13 @@ void IEEEFloat::makeLargest(bool Negative) {
   significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
                                    ? (~integerPart(0) >> NumUnusedHighBits)
                                    : 0;
+  // For E8M0 format, we only have the 'internal' precision bit
+  // (aka 'P' the precision bit) which is always set to 1.
+  // Hence, the below logic of setting the LSB to 0 does not apply.
+  // For other cases, the LSB is meant to be any bit other than
+  // the Pth precision bit.
+  if (semantics == &semFloat8E8M0FN)
+    return;
 
   if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
       semantics->nanEncoding == fltNanEncoding::AllOnes)
@@ -4509,6 +4614,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
       exponent = 0;
       if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
         sign = false;
+      // The E8M0 type cannot represent the value zero and
+      // thus the category cannot be fcZero. So, get the
+      // closest representation to fcZero instead.
+      if (semantics == &semFloat8E8M0FN)
+        makeSmallestNormalized(false);
       break;
     }
 
@@ -4575,6 +4685,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
       // denormal always increment since moving denormals and the numbers in the
       // smallest normal binade have the same exponent in our representation.
       bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes();
+      // The E8M0 format does not support Denorms.
+      // Since there are only exponents, any increment always crosses the
+      // 'BinadeBoundary'. So, make this true always.
+      if (semantics == &semFloat8E8M0FN)
+        WillCrossBinadeBoundary = true;
 
       if (WillCrossBinadeBoundary) {
         integerPart *Parts = significandParts();
@@ -4626,6 +4741,11 @@ void IEEEFloat::makeInf(bool Negative) {
 }
 
 void IEEEFloat::makeZero(bool Negative) {
+  // The E8M0 type cannot represent the value zero.
+  if (semantics == &semFloat8E8M0FN) {
+    assert(false && "This floating point format does not support Zero\n");
+    return;
+  }
   category = fcZero;
   sign = Negative;
   if (semantics->nanEncoding == fltNanEncoding::NegativeZero) {
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index be675bb7fe5a53..5cfe028ca2a3c6 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -814,8 +814,10 @@ TEST(APFloatTest, IsSmallestNormalized) {
     const fltSemantics &Semantics =
         APFloat::EnumToSemantics(static_cast<APFloat::Semantics>(I));
 
-    EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized());
-    EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized());
+    if (APFloat::hasZero(Semantics)) {
+      EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized());
+      EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized());
+    }
 
     if (APFloat::hasNanOrInf(Semantics)) {
       EXPECT_FALSE(APFloat::getInf(Semantics, false).isSmallestNormalized());
@@ -828,11 +830,22 @@ TEST(APFloatTest, IsSmallestNormalized) {
     EXPECT_FALSE(APFloat::getLargest(Semantics).isSmallestNormalized());
     EXPECT_FALSE(APFloat::getLargest(Semantics, true).isSmallestNormalized());
 
-    EXPECT_FALSE(APFloat::getSmallest(Semantics).isSmallestNormalized());
-    EXPECT_FALSE(APFloat::getSmallest(Semantics, true).isSmallestNormalized());
+    if (I != APFloat::S_Float8E8M0FN) {
+      EXPECT_FALSE(APFloat::getSmallest(Semantics).isSmallestNormalized());
+      EXPECT_FALSE(APFloat::getSmallest(Semantics, true).isSmallestNormalized());
+    } else {
+      // For E8M0 format, there are no denorms.
+      // So, getSmallest is equal to isSmallestNormalized().
+      EXPECT_TRUE(APFloat::getSmallest(Semantics).isSmallestNormalized());
+      EXPECT_TRUE(APFloat::getSmallest(Semantics, true).isSmallestNormalized());
+    }
 
     EXPECT_FALSE(APFloat::getAllOnesValue(Semantics).isSmallestNormalized());
 
+    // For E8M0 format, the below cases are tested through Float8E8M0FNNext.
+    if (I == APFloat::S_Float8E8M0FN)
+      continue;
+
     APFloat PosSmallestNormalized =
         APFloat::getSmallestNormalized(Semantics, false);
     APFloat NegSmallestNormalized =
@@ -1907,6 +1920,57 @@ TEST(DoubleAPFloatTest, isInteger) {
   EXPECT_FALSE(T3.isInteger());
 }
 
+// Test to check if the full range of Float8E8M0FN
+// values are being represented correctly.
+TEST(APFloatTest, Float8E8M0FNValues) {
+  // High end of the range
+  auto test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p127");
+  EXPECT_EQ(0x1.0p127, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p126");
+  EXPECT_EQ(0x1.0p126, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p125");
+  EXPECT_EQ(0x1.0p125, test.convertToDouble());
+
+  // tests the fix in makeLargest()
+  test = APFloat::getLargest(APFloat::Float8E8M0FN());
+  EXPECT_EQ(0x1.0p127, test.convertToDouble());
+
+  // tests overflow to nan
+  APFloat nan = APFloat(APFloat::Float8E8M0FN(), "nan");
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p128");
+  EXPECT_TRUE(test.bitwiseIsEqual(nan));
+
+  // Mid of the range
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p0");
+  EXPECT_EQ(1.0, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p1");
+  EXPECT_EQ(2.0, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p2");
+  EXPECT_EQ(4.0, test.convertToDouble());
+
+  // Low end of the range
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-125");
+  EXPECT_EQ(0x1.0p-125, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-126");
+  EXPECT_EQ(0x1.0p-126, test.convertToDouble());
+
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-127");
+  EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+
+  // Smallest value
+  test = APFloat::getSmallest(APFloat::Float8E8M0FN());
+  EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+
+  // Value below the smallest, but clamped to the smallest
+  test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-128");
+  EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+}
+
 TEST(APFloatTest, getLargest) {
   EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat());
   EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble());
@@ -1919,6 +1983,8 @@ TEST(APFloatTest, getLargest) {
       30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
   EXPECT_EQ(3.40116213421e+38f,
             APFloat::getLargest(APFloat::FloatTF32()).convertToFloat());
+  EXPECT_EQ(1.701411834e+38f,
+            APFloat::getLargest(APFloat::Float8E8M0FN()).convertToDouble());
   EXPECT_EQ(28, APFloat::getLargest(APFloat::Float6E3M2FN()).convertToDouble());
   EXPECT_EQ(7.5,
  ...
[truncated]
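
The `next()` hunks above force `WillCrossBinadeBoundary` to true and replace a would-be zero with the smallest normalized value. On the raw encoding this amounts to stepping the byte by one, which the sketch below illustrates. `nextE8M0` is a hypothetical helper and not part of the patch. Stepping up from the largest value to NaN, and NaN staying NaN, are assumptions here, since the corresponding tests are truncated above.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not an APFloat API: step an E8M0 byte to the adjacent
// representable value. Every step doubles or halves the value because the
// format has no significand bits.
inline std::uint8_t nextE8M0(std::uint8_t bits, bool nextDown) {
  constexpr std::uint8_t NaN = 0xFF;
  if (bits == NaN)
    return NaN; // Assumption: NaN is left unchanged.
  if (nextDown) {
    // There is no zero to step down to, so the smallest value (0x00, i.e.
    // 2^-127) stays where it is, mirroring makeSmallestNormalized() above.
    return bits == 0x00 ? bits : static_cast<std::uint8_t>(bits - 1);
  }
  // Assumption: stepping up from the largest value (0xFE) yields NaN (0xFF).
  return static_cast<std::uint8_t>(bits + 1);
}
```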

github-actions bot commented Sep 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@tschuett tschuett added the floating-point Floating-point math label Sep 3, 2024
@tschuett tschuett requested a review from arsenm September 3, 2024 19:47
@sergey-kozub
Contributor

Note that the NumPy E8M0 dtype (in the JAX repo) is named float8_e8m0_fnu.
PR link: jax-ml/ml_dtypes#166

Should we also add the U suffix (which stands for "unsigned") to the LLVM type name, for consistency?

@durga4github
Contributor Author

Sure, all for consistency! I was not aware of the jax side of things.
I will refresh and update the name to: "Float8E8M0FNU".

@durga4github
Contributor Author

Updated the type name to "Float8E8M0FNU" and refactored a few tests for better readability.

@@ -1165,7 +1200,7 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {

 IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
   initialize(&ourSemantics);
-  makeZero(false);
+  ourSemantics.hasZero ? makeZero(false) : makeSmallestNormalized(false);
Contributor

It's not obvious to me that makeSmallestNormalized is the alternative. What's wrong with makeZero? You'll get a zero exponent, and an empty significand?

Contributor Author

The Float8E8M0FNU format does not have a representation for zero. So, we are initializing with the closest possible representation (which happens to be the smallestNormalized for this format).

Zero is special with its own 'fcZero' category. Since we have a normal value here, we cannot use makeZero() to initialize it. (i.e. we need this to be categorized as 'fcNormal' and not as 'fcZero').

I have added a comment regarding this in the latest revision.
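
To make the substitution concrete, here is a minimal standalone sketch of how a power of two (or zero) would map to an E8M0 byte under the behaviour this revision's tests exercise: values below 2^-127 clamp to the smallest encoding and values above 2^127 become NaN. The helper names `encodeE8M0Exponent` and `encodeE8M0Zero` are hypothetical and are not APFloat APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not an APFloat API: encode the value 2^unbiasedExp.
inline std::uint8_t encodeE8M0Exponent(int unbiasedExp) {
  constexpr int Bias = 127;
  constexpr int MinExp = -127;
  constexpr int MaxExp = 127;
  if (unbiasedExp > MaxExp)
    return 0xFF; // Overflow: no infinity exists, so the result is NaN.
  if (unbiasedExp < MinExp)
    return 0x00; // Underflow: clamp to the smallest value, 2^-127.
  const int biased = unbiasedExp + Bias;
  assert(biased >= 0x00 && biased <= 0xFE);
  return static_cast<std::uint8_t>(biased);
}

// Zero has no encoding, so the closest representable value is used instead.
// It is an ordinary normal value (fcNormal in APFloat terms), not fcZero.
inline std::uint8_t encodeE8M0Zero() { return encodeE8M0Exponent(-127); }
```

The point of the sketch is that the byte 0x00 means 2^-127 rather than zero, which is why a default-constructed value has to be categorized as a normal number.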

This patch adds an APFloat type for unsigned E8M0 format.
This format is used for representing the "scale-format"
in the MX specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This format does not support {Inf, denorms, zeroes}.
Like FP32, this format's exponents are 8-bits (all bits here)
and the bias value is 127. However, it differs from IEEE-FP32
in that the minExponent is -127 (instead of -126).
There are updates done in the APFloat utility functions
to handle these constraints for this format.

* The bias calculation is different and convertIEEE* APIs
  are updated to handle this.
* Since there are no significand bits, the
  isSignificandAll{Zeroes/Ones} methods are updated accordingly.
* Although the format does not have any precision, the precision
  bit in the fltSemantics is set to 1 for consistency with
  APFloat's internal representation.
* Many utility functions are updated to handle the fact that this
  format does not support Zero.
* Provide a separate initFromAPInt() implementation to
  handle the quirks of the format.
* Add specific tests to verify the range of values for this format.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>