Skip to content

Commit

Permalink
Merge pull request #182 from graphcore-research:adding-clang-format-p…
Browse files Browse the repository at this point in the history
…re-commit

PiperOrigin-RevId: 671057090
  • Loading branch information
The ml_dtypes Authors committed Sep 4, 2024
2 parents acf7e8c + c864dc6 commit 17a83f1
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 15 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ repos:
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: 05241dc3def184dba136e62d54ff57f1c8a497a9 # frozen: v17.0.6
hooks:
- id: clang-format
files: ml_dtypes/
2 changes: 1 addition & 1 deletion ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct TypeDescriptor<float8_e3m4> : CustomFloatType<float8_e3m4> {
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e3m4";
static constexpr const char* kTpDoc = "float8_e3m4 floating-point values";
// Set e3m4 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
static constexpr char kNpyDescrKind = 'V'; // Void
static constexpr char kNpyDescrKind = 'V'; // Void
static constexpr char kNpyDescrType = '3';
static constexpr char kNpyDescrByteorder = '='; // Native byte order
};
Expand Down
14 changes: 6 additions & 8 deletions ml_dtypes/include/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ class float8_e3m4 : public float8_base<float8_e3m4> {

public:
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e3m4(T f8)
: float8_e3m4(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e3m4(T f8) : float8_e3m4(ConvertFrom(f8)) {}
};

class float8_e4m3 : public float8_base<float8_e4m3> {
Expand All @@ -269,8 +268,7 @@ class float8_e4m3 : public float8_base<float8_e4m3> {

public:
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e4m3(T f8)
: float8_e4m3(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3(T f8) : float8_e4m3(ConvertFrom(f8)) {}
};

class float8_e4m3fn : public float8_base<float8_e4m3fn> {
Expand Down Expand Up @@ -1481,25 +1479,25 @@ namespace numext {

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e3m4
bit_cast<ml_dtypes::float8_e3m4, uint8_t>(const uint8_t &src) {
bit_cast<ml_dtypes::float8_e3m4, uint8_t>(const uint8_t& src) {
return ml_dtypes::float8_e3m4::FromRep(src);
}

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t
bit_cast<uint8_t, ml_dtypes::float8_e3m4>(const ml_dtypes::float8_e3m4 &src) {
bit_cast<uint8_t, ml_dtypes::float8_e3m4>(const ml_dtypes::float8_e3m4& src) {
return src.rep();
}

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e4m3
bit_cast<ml_dtypes::float8_e4m3, uint8_t>(const uint8_t &src) {
bit_cast<ml_dtypes::float8_e4m3, uint8_t>(const uint8_t& src) {
return ml_dtypes::float8_e4m3::FromRep(src);
}

template <>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t
bit_cast<uint8_t, ml_dtypes::float8_e4m3>(const ml_dtypes::float8_e4m3 &src) {
bit_cast<uint8_t, ml_dtypes::float8_e4m3>(const ml_dtypes::float8_e4m3& src) {
return src.rep();
}

Expand Down
11 changes: 5 additions & 6 deletions ml_dtypes/tests/float8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ TEST(Float8E3m4Test, NumericLimits) {
Eigen::numext::isnan(std::numeric_limits<float8_e3m4>::quiet_NaN()));
EXPECT_TRUE(
Eigen::numext::isnan(std::numeric_limits<float8_e3m4>::signaling_NaN()));
EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e3m4>::min()),
0.25);
EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e3m4>::min()), 0.25);
EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e3m4>::max()), 15.5);
EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e3m4>::lowest()),
-15.5);
Expand Down Expand Up @@ -995,10 +994,10 @@ struct Float8CastTestParamNames {
std::pair<Type, float8_e5m2>, std::pair<Type, bool>, \
std::pair<Type, int32_t>, std::pair<Type, int64_t>

#define GEN_TYPE_PAIRS() \
GEN_DEST_TYPES(float8_e3m4), GEN_DEST_TYPES(float8_e4m3), \
GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \
GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \
#define GEN_TYPE_PAIRS() \
GEN_DEST_TYPES(float8_e3m4), GEN_DEST_TYPES(float8_e4m3), \
GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \
GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz), \
GEN_DEST_TYPES(float8_e5m2fnuz)

using Float8CastTypePairs = ::testing::Types<GEN_TYPE_PAIRS()>;
Expand Down

0 comments on commit 17a83f1

Please sign in to comment.