diff --git a/ml_dtypes/include/mxfloat.h b/ml_dtypes/include/mxfloat.h index 2741cecc..9a02862b 100644 --- a/ml_dtypes/include/mxfloat.h +++ b/ml_dtypes/include/mxfloat.h @@ -338,7 +338,7 @@ struct Traits namespace Eigen { namespace numext { -#define MXFLOAT_EIGEN_BITCAST_IMPL(Type) \ +#define MXFLOAT_EIGEN_BITCAST_AND_SIGNBIT_IMPL(Type) \ template <> \ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t bit_cast( \ const Type& x) { \ @@ -348,11 +348,16 @@ namespace numext { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Type bit_cast( \ const uint8_t& x) { \ return Type::FromRep(x); \ + } \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Type signbit(const Type& x) { \ + int8_t t = bit_cast(x) << (8 - Type::kBits); \ + return bit_cast(t >> 7); \ } -MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float6_e2m3fn) -MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float6_e3m2fn) -MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float4_e2m1fn) +MXFLOAT_EIGEN_BITCAST_AND_SIGNBIT_IMPL(ml_dtypes::float6_e2m3fn) +MXFLOAT_EIGEN_BITCAST_AND_SIGNBIT_IMPL(ml_dtypes::float6_e3m2fn) +MXFLOAT_EIGEN_BITCAST_AND_SIGNBIT_IMPL(ml_dtypes::float4_e2m1fn) #undef MXFLOAT_EIGEN_BITCAST_IMPL diff --git a/ml_dtypes/tests/mxfloat_test.cc b/ml_dtypes/tests/mxfloat_test.cc index 834c1055..53fad2e0 100644 --- a/ml_dtypes/tests/mxfloat_test.cc +++ b/ml_dtypes/tests/mxfloat_test.cc @@ -125,6 +125,14 @@ TYPED_TEST(FloatMXTest, Negate) { } } +TYPED_TEST(FloatMXTest, Signbit) { + using FloatMX = TypeParam; + + FloatMX one(1.0); + EXPECT_EQ(Eigen::numext::signbit(one).rep(), 0x00); + EXPECT_EQ(Eigen::numext::signbit(-one).rep(), 0xff); +} + TYPED_TEST(FloatMXTest, BitCasts) { using FloatMX = TypeParam;