Skip to content

Commit

Permalink
Remove more initialization by const-ref
Browse files Browse the repository at this point in the history
While we are here, implement double/long double -> float8 conversions via conversion to float. This lets us avoid doing arithmetic using 64-bit types during the float8 rounding step.

PiperOrigin-RevId: 577251071
  • Loading branch information
majnemer authored and The ml_dtypes Authors committed Oct 27, 2023
1 parent 161db24 commit 720823f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
50 changes: 38 additions & 12 deletions ml_dtypes/include/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ class float8_base {

template <typename T>
explicit EIGEN_DEVICE_FUNC float8_base(
T f, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0)
: float8_base(ConvertFrom(static_cast<float>(f)).rep(),
T i, std::enable_if_t<std::is_integral_v<T>, int> = 0)
: float8_base(ConvertFrom(static_cast<float>(i)).rep(),
ConstructFromRepTag{}) {}
explicit EIGEN_DEVICE_FUNC float8_base(double f64)
: float8_base(ConvertFrom(f64).rep(), ConstructFromRepTag{}) {}
explicit EIGEN_DEVICE_FUNC float8_base(float f32)
: float8_base(ConvertFrom(f32).rep(), ConstructFromRepTag{}) {}
template <typename T>
explicit EIGEN_DEVICE_FUNC float8_base(
T f, std::enable_if_t<std::is_floating_point_v<T>, int> = 0)
: float8_base(ConvertFrom(f).rep(), ConstructFromRepTag{}) {}
explicit EIGEN_DEVICE_FUNC float8_base(Eigen::bfloat16 bf16)
: float8_base(ConvertFrom(bf16).rep(), ConstructFromRepTag{}) {}
explicit EIGEN_DEVICE_FUNC float8_base(Eigen::half f16)
Expand Down Expand Up @@ -112,10 +112,10 @@ class float8_base {

// Conversions allowing saturation and truncation.
template <bool kSaturate = false, bool kTruncate = false, typename From>
static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(const From& from);
static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(From from);

template <typename To, bool kSaturate = false, bool kTruncate = false>
static inline EIGEN_DEVICE_FUNC To ConvertTo(const Derived& from);
static inline EIGEN_DEVICE_FUNC To ConvertTo(Derived from);

// Operators via float32.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
Expand Down Expand Up @@ -634,7 +634,8 @@ struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
}
static constexpr float8_e4m3fnuz infinity() {
return float8_e4m3fnuz::FromRep(0x80);
} // NaN.
}
// NaN.
static constexpr float8_e4m3fnuz quiet_NaN() {
return float8_e4m3fnuz::FromRep(0x80);
}
Expand Down Expand Up @@ -1239,13 +1240,38 @@ struct ConvertImpl<float8_e5m2, Eigen::half, kSaturate, kTruncate> {

template <typename Derived>
template <bool kSaturate, bool kTruncate, typename From>
EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From& from) {
return ConvertImpl<From, Derived, kSaturate, kTruncate>::run(from);
EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From from) {
  // Converts `from` to `Derived` through ConvertImpl, forwarding the
  // kSaturate / kTruncate template flags unchanged.
  //
  // Floating-point sources wider than float (double, long double, ...) are
  // first narrowed to float so that the float8 rounding step never has to do
  // arithmetic on 64-bit (or wider) types. Narrowing with the default
  // rounding and then rounding again to float8 is a double rounding and can
  // change the result, so the intermediate float is rounded *to odd*
  // instead. See: Boldo, Sylvie, and Guillaume Melquiond. "When double
  // rounding is odd." 17th IMACS World Congress. 2005.
  if constexpr (std::is_floating_point_v<From> &&
                sizeof(From) > sizeof(float)) {
    // binary64, float80, binary128, etc. end up here.
    static_assert(std::numeric_limits<From>::digits >=
                  std::numeric_limits<float>::digits + 2);
    static_assert(std::numeric_limits<float>::min_exponent >=
                  std::numeric_limits<From>::min_exponent + 2);
    static_assert(std::numeric_limits<float>::radix == 2);
    float from_rnd_float = static_cast<float>(from);

    // NaN and infinities are passed through untouched.
    if (std::isfinite(from_rnd_float)) {
      const From widened = static_cast<From>(from_rnd_float);
      if (widened != from) {
        // The narrowing was inexact. Round-to-odd must yield whichever of the
        // two floats bracketing `from` has an odd significand. An odd result
        // is already that neighbour. An even result has to be moved one ULP
        // towards the side `from` actually lies on; unconditionally setting
        // the LSB would step *away* from `from` whenever the narrowing
        // rounded away from zero. For example, the double 1 + 2^-4 - 2^-40
        // narrows to 1 + 2^-4; setting the LSB gives 1 + 2^-4 + 2^-23, which
        // lies above the source value instead of below it and can therefore
        // flip the final float8 rounding.
        uint32_t bits = Eigen::numext::bit_cast<uint32_t>(from_rnd_float);
        if ((bits & 1) == 0) {
          // `from` is non-zero here (zero converts exactly) and has the same
          // sign as `widened`, so its sign selects the direction of the
          // magnitude comparison.
          const bool rounded_away_from_zero =
              from > 0 ? widened > from : widened < from;
          // The bit pattern is monotonic in magnitude, so +/-1 moves to the
          // adjacent float even across a binade boundary. The decrement
          // cannot underflow: rounding away from zero implies a non-zero
          // even magnitude, i.e. at least 2.
          bits = rounded_away_from_zero ? bits - 1 : bits + 1;
          from_rnd_float = Eigen::numext::bit_cast<float>(bits);
        }
      }
    }
    return ConvertImpl<float, Derived, kSaturate, kTruncate>::run(
        from_rnd_float);
  } else {
    return ConvertImpl<From, Derived, kSaturate, kTruncate>::run(from);
  }
}

template <typename Derived>
template <typename To, bool kSaturate, bool kTruncate>
EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived& from) {
EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived from) {
// Forwards to the ConvertImpl specialization for the (Derived, To) pair,
// passing the kSaturate / kTruncate template flags through unchanged.
return ConvertImpl<Derived, To, kSaturate, kTruncate>::run(from);
}

Expand Down
36 changes: 35 additions & 1 deletion ml_dtypes/tests/float8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,40 @@ TYPED_TEST(Float8Test, ConvertTo) {
}
}

template <typename SrcType, typename IntermediateType, typename Float8>
static SrcType DoubleRoundHelper() {
  // Builds a SrcType value of the form 1.0..010..010.. and returns it after a
  // round trip through Float8.
  //
  // The first extra bit sits at 2^-digits(Float8) and the second at
  // 2^-digits(IntermediateType). Rounding to nearest-even in two steps can
  // first drop the lower bit on a tie, which then makes the upper bit look
  // like a tie as well and drops it too. A single correctly rounded
  // conversion sees no tie at all, so the tie-breaking rule must not apply.
  constexpr int kFloat8Digits = std::numeric_limits<Float8>::digits;
  constexpr int kIntermediateDigits =
      std::numeric_limits<IntermediateType>::digits;

  const SrcType unit{1.0};
  SrcType value = unit;
  value += std::ldexp(unit, -kFloat8Digits);
  value += std::ldexp(unit, -kIntermediateDigits);

  return static_cast<SrcType>(static_cast<Float8>(value));
}

// Guards against `float8_base::ConvertFrom` being implemented as a chain of
// conversions. Going double -> float -> float8, for example, rounds twice, and
// the second rounding may then no longer be faithful to the original value.
// The inputs produced by DoubleRoundHelper are crafted to expose that.
TYPED_TEST(Float8Test, DoubleRound) {
  using Float8 = TypeParam;

  // A correct conversion rounds the crafted input up to the value right after
  // 1; a conversion that suffers from double rounding returns 1 instead.
  const double one_plus_epsilon =
      static_cast<double>(std::numeric_limits<Float8>::epsilon()) + 1.0;

  // Don't use long double on targets which don't support it.
#if !defined(EIGEN_USE_GPU) && !defined(EIGEN_GPU_COMPILE_PHASE)
  EXPECT_EQ((DoubleRoundHelper<long double, double, Float8>()),
            one_plus_epsilon);
  EXPECT_EQ((DoubleRoundHelper<long double, float, Float8>()),
            one_plus_epsilon);
#endif
  EXPECT_EQ((DoubleRoundHelper<double, float, Float8>()), one_plus_epsilon);
}

TEST(Float8Test, Float8E5m2_To_Float8E4m3) {
// Saturation.
float8_e5m2 max = std::numeric_limits<float8_e5m2>::max();
Expand Down Expand Up @@ -677,7 +711,7 @@ TYPED_TEST(Float8Test, CallTheConstOperator) {
}
}

TEST(Float855m2Test, SmallCastToDenormal) {
TEST(Float8E5m2Test, SmallCastToDenormal) {
// Special edge-case where rounding to a normalized value would
// normally round down, but rounding to a subnormal rounds up.
float x = std::ldexp(1.3125, -15);
Expand Down

0 comments on commit 720823f

Please sign in to comment.