diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index e073b84f..ad286861 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -61,13 +61,13 @@ class float8_base { template explicit EIGEN_DEVICE_FUNC float8_base( - T f, std::enable_if_t, int> = 0) - : float8_base(ConvertFrom(static_cast(f)).rep(), + T i, std::enable_if_t, int> = 0) + : float8_base(ConvertFrom(static_cast(i)).rep(), ConstructFromRepTag{}) {} - explicit EIGEN_DEVICE_FUNC float8_base(double f64) - : float8_base(ConvertFrom(f64).rep(), ConstructFromRepTag{}) {} - explicit EIGEN_DEVICE_FUNC float8_base(float f32) - : float8_base(ConvertFrom(f32).rep(), ConstructFromRepTag{}) {} + template + explicit EIGEN_DEVICE_FUNC float8_base( + T f, std::enable_if_t, int> = 0) + : float8_base(ConvertFrom(f).rep(), ConstructFromRepTag{}) {} explicit EIGEN_DEVICE_FUNC float8_base(Eigen::bfloat16 bf16) : float8_base(ConvertFrom(bf16).rep(), ConstructFromRepTag{}) {} explicit EIGEN_DEVICE_FUNC float8_base(Eigen::half f16) @@ -112,10 +112,10 @@ class float8_base { // Conversions allowing saturation and truncation. template - static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(const From& from); + static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(From from); template - static inline EIGEN_DEVICE_FUNC To ConvertTo(const Derived& from); + static inline EIGEN_DEVICE_FUNC To ConvertTo(Derived from); // Operators via float32. EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived @@ -634,7 +634,8 @@ struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base { } static constexpr float8_e4m3fnuz infinity() { return float8_e4m3fnuz::FromRep(0x80); - } // NaN. + } + // NaN. 
static constexpr float8_e4m3fnuz quiet_NaN() { return float8_e4m3fnuz::FromRep(0x80); } @@ -1239,13 +1240,38 @@ struct ConvertImpl { template template -EIGEN_DEVICE_FUNC Derived float8_base::ConvertFrom(const From& from) { - return ConvertImpl::run(from); +EIGEN_DEVICE_FUNC Derived float8_base::ConvertFrom(const From from) { + // We are rounding double/long double -> float -> float8. This can induce + // double-rounding which may alter the results. We can correct for this using + // a trick explained in: Boldo, Sylvie, and Guillaume Melquiond. "When double + // rounding is odd." 17th IMACS World Congress. 2005. + if constexpr (std::is_floating_point_v && + sizeof(From) > sizeof(float)) { + // binary64, float80, binary128, etc. end up here. + static_assert(std::numeric_limits::digits >= + std::numeric_limits::digits + 2); + static_assert(std::numeric_limits::min_exponent >= + std::numeric_limits::min_exponent + 2); + static_assert(std::numeric_limits::radix == 2); + float from_rnd_float = static_cast(from); + + // Round-to-odd involves us setting the LSB if we dropped any bits while + // rounding. NOTE(review): the static_cast above presumably rounds to nearest rather than truncating; if it landed on an even mantissa farther from zero than `from`, OR-ing in the LSB below steps further away instead of selecting the adjacent odd value. Confirm against the cited paper. 
+ if (std::isfinite(from_rnd_float) && + static_cast(from_rnd_float) != from) { + from_rnd_float = Eigen::numext::bit_cast( + Eigen::numext::bit_cast(from_rnd_float) | 1); + } + return ConvertImpl::run( + from_rnd_float); + } else { + return ConvertImpl::run(from); + } } template template -EIGEN_DEVICE_FUNC To float8_base::ConvertTo(const Derived& from) { +EIGEN_DEVICE_FUNC To float8_base::ConvertTo(const Derived from) { return ConvertImpl::run(from); } diff --git a/ml_dtypes/tests/float8_test.cc b/ml_dtypes/tests/float8_test.cc index b8dbb685..765428b8 100644 --- a/ml_dtypes/tests/float8_test.cc +++ b/ml_dtypes/tests/float8_test.cc @@ -407,6 +407,40 @@ TYPED_TEST(Float8Test, ConvertTo) { } } +template +static SrcType DoubleRoundHelper() { + // If we have a number of the form 1.0..010..010.., two rounds of RTNE can + // cause the last-set bit to get rounded down due to RTNE which in turn will + // cause the other bit to get rounded down due to RTNE. RTNE's tie breaking + // semantics *should* not apply here as there is no tie but double-rounding + // may confuse us. + SrcType x{1.0}; + x += std::ldexp(SrcType{1.0}, -std::numeric_limits::digits); + x += std::ldexp(SrcType{1.0}, -std::numeric_limits::digits); + auto rounded_x = static_cast(x); + return static_cast(rounded_x); +} + +// This test tries to capture mistakes in `float8_base::ConvertFrom` where it is +// implemented by a series of conversions. e.g. converting a double to a float +// to a float8 introduces double-rounding which makes the final rounding step +// unfaithful. Craft a variety of numbers which try to detect if this happens. +TYPED_TEST(Float8Test, DoubleRound) { + using Float8 = TypeParam; + + // We expect that our number results in rounding up to the number after 1. + // Incorrect rounding will result in 1. + const double expected = + 1.0 + static_cast(std::numeric_limits::epsilon()); + + // Don't use long double on targets which don't support it. 
+#if !defined(EIGEN_USE_GPU) && !defined(EIGEN_GPU_COMPILE_PHASE) + EXPECT_EQ((DoubleRoundHelper()), expected); + EXPECT_EQ((DoubleRoundHelper()), expected); +#endif + EXPECT_EQ((DoubleRoundHelper()), expected); +} + TEST(Float8Test, Float8E5m2_To_Float8E4m3) { // Saturation. float8_e5m2 max = std::numeric_limits::max(); @@ -677,7 +711,7 @@ TYPED_TEST(Float8Test, CallTheConstOperator) { } } -TEST(Float855m2Test, SmallCastToDenormal) { +TEST(Float8E5m2Test, SmallCastToDenormal) { // Special edge-case where rounding to a normalized value would // normally round down, but rounding to a subnormal rounds up. float x = std::ldexp(1.3125, -15);