Skip to content

Commit

Permalink
[APFloat] Add APFloat support for E8M0 type
Browse files Browse the repository at this point in the history
This patch adds an APFloat type for unsigned E8M0 format.
This format is used for representing the "scale-format"
in the MX specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This format does not support {Inf, denorms, zeroes}.
Like FP32, this format has an 8-bit exponent (here, all 8 bits of the
value) and a bias of 127. However, it differs from IEEE-FP32
in that the minExponent is -127 (instead of -126).
The APFloat utility functions are updated to handle these
constraints for this format:

* The bias calculation is different and convertIEEE* APIs
  are updated to handle this.
* Since there are no significand bits, the
  isSignificandAll{Zeroes/Ones} methods are updated accordingly.
* Although the format has no significand bits, the precision
  field in its fltSemantics is set to 1 for consistency with
  APFloat's internal representation.
* Many utility functions are updated to handle the fact that this
  format does not support Zero.
* Provide a separate initFromAPInt() implementation to
  handle the quirks of the format.
* Add specific tests to verify the range of values for this format.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
  • Loading branch information
durga4github committed Sep 4, 2024
1 parent a8e1c6f commit 26f3c85
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 21 deletions.
37 changes: 37 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ struct APFloatBase {
// improved range compared to half (16-bit) formats, at (potentially)
// greater throughput than single precision (32-bit) formats.
S_FloatTF32,
// 8-bit floating point number with (all the) 8 bits for the exponent
// like in FP32. There are no zeroes, no infinities, and no denormal values.
// NaN is represented with all bits set to 1. Bias is 127.
// This represents the scale data type in the MX specification from
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
S_Float8E8M0FN,
// 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754
// types, there are no infinity or NaN values. The format is detailed in
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
Expand Down Expand Up @@ -229,6 +235,7 @@ struct APFloatBase {
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E3M4() LLVM_READNONE;
static const fltSemantics &FloatTF32() LLVM_READNONE;
static const fltSemantics &Float8E8M0FN() LLVM_READNONE;
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
Expand Down Expand Up @@ -652,6 +659,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
APInt convertFloat8E3M4APFloatToAPInt() const;
APInt convertFloatTF32APFloatToAPInt() const;
APInt convertFloat8E8M0FNAPFloatToAPInt() const;
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
APInt convertFloat4E2M1FNAPFloatToAPInt() const;
Expand All @@ -672,6 +680,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
void initFromFloat8E3M4APInt(const APInt &api);
void initFromFloatTF32APInt(const APInt &api);
void initFromFloat8E8M0FNAPInt(const APInt &api);
void initFromFloat6E3M2FNAPInt(const APInt &api);
void initFromFloat6E2M3FNAPInt(const APInt &api);
void initFromFloat4E2M1FNAPInt(const APInt &api);
Expand Down Expand Up @@ -1079,6 +1088,9 @@ class APFloat : public APFloatBase {
/// \param Semantics - type float semantics
static APFloat getAllOnesValue(const fltSemantics &Semantics);

/// Returns true if the given semantics supports either NaN or Infinity.
///
/// \param Sem - type float semantics
static bool hasNanOrInf(const fltSemantics &Sem) {
switch (SemanticsToEnum(Sem)) {
default:
Expand All @@ -1091,6 +1103,31 @@ class APFloat : public APFloatBase {
}
}

/// Reports whether values of the given semantics can encode Zero.
///
/// Float8E8M0FN is the only semantics for which this is false; every other
/// semantics reports true.
///
/// \param Sem - type float semantics
static bool hasZero(const fltSemantics &Sem) {
  return SemanticsToEnum(Sem) != APFloat::S_Float8E8M0FN;
}

/// Reports whether the given semantics consists solely of an exponent,
/// i.e. has no significand. Only Float8E8M0FN reports true.
///
/// \param Sem - type float semantics
static bool hasExponentOnly(const fltSemantics &Sem) {
  return SemanticsToEnum(Sem) == APFloat::S_Float8E8M0FN;
}

/// Used to insert APFloat objects, or objects that contain APFloat objects,
/// into FoldingSets.
void Profile(FoldingSetNodeID &NID) const;
Expand Down
126 changes: 123 additions & 3 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
// Semantics for the unsigned E8M0 "scale" format: 8 bits in total, all of
// them exponent bits. NaN is the only non-finite value (NanOnly) and it is
// encoded with all bits set (AllOnes).
// NOTE(review): the four leading integers presumably initialize
// {maxExponent, minExponent, precision, sizeInBits} in that order -- confirm
// against the fltSemantics definition. Under that reading minExponent is
// -127, and precision is 1 even though the format has no significand bits,
// because the internal representation keeps one always-set precision bit.
static constexpr fltSemantics semFloat8E8M0FN = {
127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};

static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat6E2M3FN = {
Expand Down Expand Up @@ -222,6 +225,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E3M4();
case S_FloatTF32:
return FloatTF32();
case S_Float8E8M0FN:
return Float8E8M0FN();
case S_Float6E3M2FN:
return Float6E3M2FN();
case S_Float6E2M3FN:
Expand Down Expand Up @@ -264,6 +269,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E3M4;
else if (&Sem == &llvm::APFloat::FloatTF32())
return S_FloatTF32;
else if (&Sem == &llvm::APFloat::Float8E8M0FN())
return S_Float8E8M0FN;
else if (&Sem == &llvm::APFloat::Float6E3M2FN())
return S_Float6E3M2FN;
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
Expand Down Expand Up @@ -294,6 +301,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
}
// Accessors for the per-format semantics objects: each one returns a
// reference to the matching sem* constant.
const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::Float8E8M0FN() { return semFloat8E8M0FN; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
Expand Down Expand Up @@ -396,6 +404,8 @@ static inline Error createError(const Twine &Err) {
}

/// Number of integerPartWidth-sized parts needed to hold \p bits bits,
/// rounded up. A request for zero bits still reports one part.
static constexpr inline unsigned int partCountForBits(unsigned int bits) {
  return bits == 0 ? 1
                   : (bits + APFloatBase::integerPartWidth - 1) /
                         APFloatBase::integerPartWidth;
}

Expand Down Expand Up @@ -955,6 +965,12 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
significand[part] = 0;
}

// For the E8M0 type, precision is just 1 and the
// QNaNBit handling below is not relevant.
// So, exit early.
if (semantics == &semFloat8E8M0FN)
return;

unsigned QNaNBit = semantics->precision - 2;

if (SNaN) {
Expand Down Expand Up @@ -1007,6 +1023,10 @@ IEEEFloat &IEEEFloat::operator=(IEEEFloat &&rhs) {
}

bool IEEEFloat::isDenormal() const {
// The E8M0 format does not support denormals.
if (semantics == &semFloat8E8M0FN)
return false;

return isFiniteNonZero() && (exponent == semantics->minExponent) &&
(APInt::tcExtractBit(significandParts(),
semantics->precision - 1) == 0);
Expand All @@ -1028,6 +1048,10 @@ bool IEEEFloat::isSmallestNormalized() const {
bool IEEEFloat::isSignificandAllOnes() const {
// Test if the significand excluding the integral bit is all ones. This allows
// us to test for binade boundaries.
// For the E8M0 format, this is always false since there are no
// actual significand bits.
if (semantics == &semFloat8E8M0FN)
return false;
const integerPart *Parts = significandParts();
const unsigned PartCount = partCountForBits(semantics->precision);
for (unsigned i = 0; i < PartCount - 1; i++)
Expand Down Expand Up @@ -1075,6 +1099,11 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
}

bool IEEEFloat::isSignificandAllZeros() const {
// For the E8M0 format, this is always true since there are no
// actual significand bits.
if (semantics == &semFloat8E8M0FN)
return true;

// Test if the significand excluding the integral bit is all zeros. This
// allows us to test for binade boundaries.
const integerPart *Parts = significandParts();
Expand Down Expand Up @@ -1113,6 +1142,8 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const {
}

bool IEEEFloat::isLargest() const {
if (semantics == &semFloat8E8M0FN)
return isFiniteNonZero() && exponent == semantics->maxExponent;
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
semantics->nanEncoding == fltNanEncoding::AllOnes) {
// The largest number by magnitude in our format will be the floating point
Expand Down Expand Up @@ -1165,6 +1196,12 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {

/// Construct a value of the given semantics holding that format's default:
/// positive zero, or -- for E8M0, which cannot represent zero -- the
/// smallest normalized value as the closest representation.
IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
  initialize(&ourSemantics);

  const bool CanRepresentZero = semantics != &semFloat8E8M0FN;
  if (CanRepresentZero)
    makeZero(false);
  else
    makeSmallestNormalized(false);
}

Expand Down Expand Up @@ -1727,6 +1764,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
/* Canonicalize zeroes. */
if (omsb == 0) {
category = fcZero;
// The E8M0 type cannot represent the value zero and
// thus the category cannot be fcZero. So, get the
// closest representation to fcZero instead.
if (semantics == &semFloat8E8M0FN)
makeSmallestNormalized(false);
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
}
Expand Down Expand Up @@ -2606,6 +2648,11 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics,
fs = opOK;
}

// The E8M0 type cannot represent the value zero and
// thus the category cannot be fcZero. So, get the
// closest representation to fcZero instead.
if (category == fcZero && semantics == &semFloat8E8M0FN)
makeSmallestNormalized(false);
return fs;
}

Expand Down Expand Up @@ -3070,6 +3117,11 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) {
fs = opOK;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
// The E8M0 type cannot represent the value zero and
// thus the category cannot be fcZero. So, get the
// closest representation to fcZero instead.
if (semantics == &semFloat8E8M0FN)
makeSmallestNormalized(false);

/* Check whether the normalized exponent is high enough to overflow
max during the log-rebasing in the max-exponent check below. */
Expand Down Expand Up @@ -3533,15 +3585,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const {
template <const fltSemantics &S>
APInt IEEEFloat::convertIEEEFloatToAPInt() const {
assert(semantics == &S);

constexpr int bias = -(S.minExponent - 1);
const int bias =
(semantics == &semFloat8E8M0FN) ? -S.minExponent : -(S.minExponent - 1);
constexpr unsigned int trailing_significand_bits = S.precision - 1;
constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
constexpr integerPart integer_bit =
integerPart{1} << (trailing_significand_bits % integerPartWidth);
constexpr uint64_t significand_mask = integer_bit - 1;
constexpr unsigned int exponent_bits =
S.sizeInBits - 1 - trailing_significand_bits;
trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits)
: S.sizeInBits;
static_assert(exponent_bits < 64);
constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;

Expand All @@ -3557,6 +3610,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
!(significandParts()[integer_bit_part] & integer_bit))
myexponent = 0; // denormal
} else if (category == fcZero) {
if (semantics == &semFloat8E8M0FN)
llvm_unreachable("semantics does not support zero!");
myexponent = ::exponentZero(S) + bias;
mysignificand.fill(0);
} else if (category == fcInfinity) {
Expand Down Expand Up @@ -3659,6 +3714,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloatTF32>();
}

// Produces the APInt encoding of an E8M0 value by delegating to the generic
// convertIEEEFloatToAPInt template instantiated for semFloat8E8M0FN.
APInt IEEEFloat::convertFloat8E8M0FNAPFloatToAPInt() const {
// The value is expected to occupy exactly one part.
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E8M0FN>();
}

APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
Expand Down Expand Up @@ -3721,6 +3781,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
return convertFloatTF32APFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FN)
return convertFloat8E8M0FNAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN)
return convertFloat6E3M2FNAPFloatToAPInt();

Expand Down Expand Up @@ -3819,6 +3882,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) {
}
}

// Initialize this value from the raw bits of an E8M0 encoding.
//
// The E8M0 format has the following characteristics:
//   * It is an 8-bit unsigned format with only exponent bits (no actual
//     significand).
//   * There are no encodings for zero, infinities or denormals.
//   * NaN is represented by all 1's (0xff).
//   * The exponent bias is 127.
//
// \param api - the encoding; only the low 8 bits of its first word are used.
void IEEEFloat::initFromFloat8E8M0FNAPInt(const APInt &api) {
  const uint64_t exponent_mask = 0xff;
  const uint64_t myexponent = api.getRawData()[0] & exponent_mask;

  initialize(&semFloat8E8M0FN);
  assert(partCount() == 1);

  // This format has an unsigned representation only.
  sign = 0;

  // The format has no significand bits, but the 'Pth' precision bit is always
  // set to 1 for consistency with APFloat's internal representation.
  significandParts()[0] = 1;

  // A value is either NaN (all 1's, i.e. 255) or fcNormal. The check is made
  // on the masked exponent, the same bits the exponent is derived from, so
  // that stray bits above the low byte cannot turn the NaN encoding into an
  // fcNormal value whose exponent is out of range.
  if (myexponent == exponent_mask) {
    category = fcNaN;
    exponent = exponentNaN();
    return;
  }

  // Handle fcNormal: remove the bias to get the unbiased exponent.
  category = fcNormal;
  exponent = myexponent - 127; // 127 is the bias
}
template <const fltSemantics &S>
void IEEEFloat::initFromIEEEAPInt(const APInt &api) {
assert(api.getBitWidth() == S.sizeInBits);
Expand Down Expand Up @@ -3999,6 +4096,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E3M4APInt(api);
if (Sem == &semFloatTF32)
return initFromFloatTF32APInt(api);
if (Sem == &semFloat8E8M0FN)
return initFromFloat8E8M0FNAPInt(api);
if (Sem == &semFloat6E3M2FN)
return initFromFloat6E3M2FNAPInt(api);
if (Sem == &semFloat6E2M3FN)
Expand Down Expand Up @@ -4032,6 +4131,13 @@ void IEEEFloat::makeLargest(bool Negative) {
significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
? (~integerPart(0) >> NumUnusedHighBits)
: 0;
// For E8M0 format, we only have the 'internal' precision bit
// (aka 'P' the precision bit) which is always set to 1.
// Hence, the below logic of setting the LSB to 0 does not apply.
// For other cases, the LSB is meant to be any bit other than
// the Pth precision bit.
if (semantics == &semFloat8E8M0FN)
return;

if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
semantics->nanEncoding == fltNanEncoding::AllOnes)
Expand Down Expand Up @@ -4509,6 +4615,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
exponent = 0;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
// The E8M0 type cannot represent the value zero and
// thus the category cannot be fcZero. So, get the
// closest representation to fcZero instead.
if (semantics == &semFloat8E8M0FN)
makeSmallestNormalized(false);
break;
}

Expand Down Expand Up @@ -4575,6 +4686,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
// denormal always increment since moving denormals and the numbers in the
// smallest normal binade have the same exponent in our representation.
bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes();
// The E8M0 format does not support Denorms.
// Since there are only exponents, any increment always crosses the
// 'BinadeBoundary'. So, make this true always.
if (semantics == &semFloat8E8M0FN)
WillCrossBinadeBoundary = true;

if (WillCrossBinadeBoundary) {
integerPart *Parts = significandParts();
Expand Down Expand Up @@ -4626,6 +4742,10 @@ void IEEEFloat::makeInf(bool Negative) {
}

void IEEEFloat::makeZero(bool Negative) {
// The E8M0 type cannot represent the value zero.
if (semantics == &semFloat8E8M0FN)
llvm_unreachable("This floating point format does not support Zero");

category = fcZero;
sign = Negative;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero) {
Expand Down
Loading

0 comments on commit 26f3c85

Please sign in to comment.