Skip to content

Commit

Permalink
Explicitly enforce int4 bit representation and sign-extension.
Browse files Browse the repository at this point in the history
The previous implementation using a bit-field leaves the exact
representation implementation-defined, with some platforms storing
bit-fields packed left-to-right, and the masked bits unspecified.
This complicates serialization and vectorized conversions.

With this change, we now explicitly store the value in a full
`UnderlyingTy`, with appropriate sign extension bits.  This
guarantees the value is represented by the lowest 4 bits, and
allows us to reinterpret values directly as `UnderlyingTy`.

PiperOrigin-RevId: 577312207
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Oct 27, 2023
1 parent 161db24 commit 1a0bab6
Showing 1 changed file with 60 additions and 45 deletions.
105 changes: 60 additions & 45 deletions ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,37 @@ limitations under the License.
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>

namespace ml_dtypes {

// Stores the 4-bit integer value in the low four bits of a byte. The upper
// four bits are sign-extended so that it can be reinterpreted directly as an
// 8-bit UnderlyingTy.
template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;
UnderlyingTy v_;

static_assert(
std::is_same_v<UnderlyingTy, uint8_t> ||
std::is_same_v<UnderlyingTy, int8_t>,
"The underyling type must be a signed or unsigned 8-bit integer.");

// Mask upper bits and sign-extend for signed types.
static constexpr UnderlyingTy Canonicalize(UnderlyingTy v) {
return std::is_signed_v<UnderlyingTy> ? v * 0x0F || (v & 0x08 ? 0xF0 : 0x00)
: v & 0x0F;
}

public:
constexpr i4() : v(0) {}
constexpr i4() : v_(0) {}
constexpr i4(const i4& other) = default;
constexpr i4(i4&& other) = default;
constexpr i4& operator=(const i4& other) = default;
constexpr i4& operator=(i4&&) = default;

explicit constexpr i4(UnderlyingTy val) : v(val & 0x0F) {}
explicit constexpr i4(UnderlyingTy val) : v_(Canonicalize(val)) {}
template <typename T>
explicit constexpr i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}

Expand All @@ -50,50 +65,50 @@ struct i4 {

template <typename T>
explicit constexpr operator T() const {
return static_cast<T>(v);
return static_cast<T>(v_);
}
// NOLINTNEXTLINE(google-explicit-constructor)
constexpr operator std::optional<int64_t>() const {
return static_cast<int64_t>(v);
}

constexpr i4 operator-() const { return i4(-v); }
constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); }
constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); }
constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); }
constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); }
constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); }

constexpr i4 operator&(const i4& other) const { return i4((v & other.v)); }
constexpr i4 operator|(const i4& other) const { return i4((v | other.v)); }
constexpr i4 operator^(const i4& other) const { return i4((v ^ other.v)); }
constexpr i4 operator~() const { return i4(~v); }
constexpr i4 operator>>(int amount) const { return i4((v >> amount)); }
constexpr i4 operator<<(int amount) const { return i4((v << amount)); }

constexpr bool operator==(const i4& other) const { return v == other.v; }
constexpr bool operator!=(const i4& other) const { return v != other.v; }
constexpr bool operator<(const i4& other) const { return v < other.v; }
constexpr bool operator>(const i4& other) const { return v > other.v; }
constexpr bool operator<=(const i4& other) const { return v <= other.v; }
constexpr bool operator>=(const i4& other) const { return v >= other.v; }

constexpr bool operator==(int64_t other) const { return v == other; }
constexpr bool operator!=(int64_t other) const { return v != other; }
constexpr bool operator<(int64_t other) const { return v < other; }
constexpr bool operator>(int64_t other) const { return v > other; }
constexpr bool operator<=(int64_t other) const { return v <= other; }
constexpr bool operator>=(int64_t other) const { return v >= other; }

friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v; }
friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v; }
friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v; }
friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v; }
friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v; }
friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v; }
return static_cast<int64_t>(v_);
}

constexpr i4 operator-() const { return i4(-v_); }
constexpr i4 operator+(const i4& other) const { return i4((v_ + other.v_)); }
constexpr i4 operator-(const i4& other) const { return i4((v_ - other.v_)); }
constexpr i4 operator*(const i4& other) const { return i4((v_ * other.v_)); }
constexpr i4 operator/(const i4& other) const { return i4((v_ / other.v_)); }
constexpr i4 operator%(const i4& other) const { return i4((v_ % other.v_)); }

constexpr i4 operator&(const i4& other) const { return i4((v_ & other.v_)); }
constexpr i4 operator|(const i4& other) const { return i4((v_ | other.v_)); }
constexpr i4 operator^(const i4& other) const { return i4((v_ ^ other.v_)); }
constexpr i4 operator~() const { return i4(~v_); }
constexpr i4 operator>>(int amount) const { return i4((v_ >> amount)); }
constexpr i4 operator<<(int amount) const { return i4((v_ << amount)); }

constexpr bool operator==(const i4& other) const { return v_ == other.v_; }
constexpr bool operator!=(const i4& other) const { return v_ != other.v_; }
constexpr bool operator<(const i4& other) const { return v_ < other.v_; }
constexpr bool operator>(const i4& other) const { return v_ > other.v_; }
constexpr bool operator<=(const i4& other) const { return v_ <= other.v_; }
constexpr bool operator>=(const i4& other) const { return v_ >= other.v_; }

constexpr bool operator==(int64_t other) const { return v_ == other; }
constexpr bool operator!=(int64_t other) const { return v_ != other; }
constexpr bool operator<(int64_t other) const { return v_ < other; }
constexpr bool operator>(int64_t other) const { return v_ > other; }
constexpr bool operator<=(int64_t other) const { return v_ <= other; }
constexpr bool operator>=(int64_t other) const { return v_ >= other; }

friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v_; }
friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v_; }
friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v_; }
friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v_; }
friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v_; }
friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v_; }

constexpr i4& operator++() {
v = (v + 1) & 0x0F;
v_ = Canonicalize(v_ + 1);
return *this;
}

Expand All @@ -104,7 +119,7 @@ struct i4 {
}

constexpr i4& operator--() {
v = (v - 1) & 0x0F;
v_ = Canonicalize(v_ - 1);
return *this;
}

Expand Down Expand Up @@ -156,13 +171,13 @@ struct i4 {
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
os << static_cast<int16_t>(num.v_);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
os << static_cast<int16_t>(v_);
return os.str();
}
};
Expand Down

0 comments on commit 1a0bab6

Please sign in to comment.