Skip to content

Commit

Permalink
Allow the copy constructor to use copy-by-value as float8 types are q…
Browse files Browse the repository at this point in the history
…uite small

No functional change is intended.

PiperOrigin-RevId: 576146641
  • Loading branch information
majnemer authored and The ml_dtypes Authors committed Oct 25, 2023
1 parent 348fd37 commit 0fb78fc
Showing 1 changed file with 15 additions and 27 deletions.
42 changes: 15 additions & 27 deletions ml_dtypes/include/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class float8_base {
public:
constexpr float8_base() : rep_(0) {}

template <typename T,
typename EnableIf = std::enable_if<std::is_arithmetic_v<T>>>
explicit EIGEN_DEVICE_FUNC float8_base(T f)
template <typename T>
explicit EIGEN_DEVICE_FUNC float8_base(
T f, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0)
: float8_base(ConvertFrom(static_cast<float>(f)).rep(),
ConstructFromRepTag{}) {}
explicit EIGEN_DEVICE_FUNC float8_base(double f64)
Expand Down Expand Up @@ -239,6 +239,10 @@ class float8_base {
uint8_t rep_;
};

template <typename T>
using RequiresIsDerivedFromFloat8Base =
std::enable_if_t<std::is_base_of_v<float8_base<T>, T>, int>;

class float8_e4m3fn : public float8_base<float8_e4m3fn> {
// Exponent: 4, Mantissa: 3, bias: 7.
// Extended range: no inf, NaN represented by 0bS111'1111.
Expand All @@ -252,9 +256,8 @@ class float8_e4m3fn : public float8_base<float8_e4m3fn> {
using Base::Base;

public:
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e5m2& f8)
: float8_e4m3fn(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e4m3b11fnuz& f8)
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e4m3fn(T f8)
: float8_e4m3fn(ConvertFrom(f8)) {}
};

Expand All @@ -267,13 +270,8 @@ class float8_e4m3b11fnuz : public float8_base<float8_e4m3b11fnuz> {
using Base::Base;

public:
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2& f8)
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2fnuz& f8)
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fn& f8)
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fnuz& f8)
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(T f8)
: float8_e4m3b11fnuz(ConvertFrom(f8)) {}

constexpr float8_e4m3b11fnuz operator-() const {
Expand Down Expand Up @@ -315,13 +313,8 @@ class float8_e4m3fnuz : public float8_base<float8_e4m3fnuz> {
using Base::Base;

public:
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2& f8)
: float8_e4m3fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2fnuz& f8)
: float8_e4m3fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3b11fnuz& f8)
: float8_e4m3fnuz(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3fn& f8)
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(T f8)
: float8_e4m3fnuz(ConvertFrom(f8)) {}

constexpr float8_e4m3fnuz operator-() const {
Expand All @@ -347,13 +340,8 @@ class float8_e5m2 : public float8_base<float8_e5m2> {
using Base::Base;

public:
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fn f8)
: float8_e5m2(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fnuz f8)
: float8_e5m2(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3b11fnuz f8)
: float8_e5m2(ConvertFrom(f8)) {}
explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e5m2fnuz& f8)
template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
explicit EIGEN_DEVICE_FUNC float8_e5m2(T f8)
: float8_e5m2(ConvertFrom(f8)) {}
};

Expand Down

0 comments on commit 0fb78fc

Please sign in to comment.