From 28cc246d95fa05bec0019eab0c507cbdca5a0e34 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 31 Aug 2023 15:32:38 -0700 Subject: [PATCH] Improve int4 constexpr-ness, add more operators, numeric_limits. This is to allow better support for int4 in C++. PiperOrigin-RevId: 561778414 --- ml_dtypes/include/int4.h | 227 +++++++++++++++++--- ml_dtypes/tests/int4_test.cc | 397 +++++++++++++++++++++++++++++++++++ 2 files changed, 592 insertions(+), 32 deletions(-) create mode 100644 ml_dtypes/tests/int4_test.cc diff --git a/ml_dtypes/include/int4.h b/ml_dtypes/include/int4.h index 51b2c2d5..c8ec5b06 100644 --- a/ml_dtypes/include/int4.h +++ b/ml_dtypes/include/int4.h @@ -17,6 +17,7 @@ limitations under the License. #define ML_DTYPES_INT4_H_ #include +#include #include #include #include @@ -30,11 +31,15 @@ struct i4 { UnderlyingTy v : 4; public: - i4() : v(0) {} - explicit i4(UnderlyingTy val) : v(val & 0x0F) {} + constexpr i4() : v(0) {} + constexpr i4(const i4& other) = default; + constexpr i4(i4&& other) = default; + constexpr i4& operator=(const i4& other) = default; + constexpr i4& operator=(i4&&) = default; + + explicit constexpr i4(UnderlyingTy val) : v(val & 0x0F) {} template - explicit i4(T t) : i4(static_cast(t)) {} - i4(const i4& other) = default; + explicit constexpr i4(T t) : i4(static_cast(t)) {} static constexpr i4 lowest() { return std::is_signed::value ? i4(-8) : i4(0); @@ -44,41 +49,112 @@ struct i4 { } template >> - explicit operator T() const { + explicit constexpr operator T() const { return static_cast(v); } // NOLINTNEXTLINE(google-explicit-constructor) - operator std::optional() const { return static_cast(v); } - - i4 operator-() const { return i4(-v); } - i4 operator+(const i4& other) const { return i4((v + other.v)); } - i4 operator-(const i4& other) const { return i4((v - other.v)); } - i4 operator*(const i4& other) const { return i4((v * other.v)); } - i4 operator/(const i4& other) const { return i4((v / other.v)); } - i4 operator%(const i4& other) const { return i4((v % other.v)); } - - i4 operator>>(const int amount) const { return i4((v >> amount)); } - i4 operator<<(const int amount) const { return i4((v << amount)); } - - bool operator==(const i4& other) const { return v == other.v; } - bool operator!=(const i4& other) const { return v != other.v; } - bool operator<(const i4& other) const { return v < other.v; } - bool operator>(const i4& other) const { return v > other.v; } - bool operator<=(const i4& other) const { return v <= other.v; } - bool operator>=(const i4& other) const { return v >= other.v; } - - bool operator==(const int64_t other) const { return v == other; } - bool operator!=(const int64_t other) const { return v != other; } - bool operator<(const int64_t other) const { return v < other; } - bool operator>(const int64_t other) const { return v > other; } - bool operator<=(const int64_t other) const { return v <= other; } - bool operator>=(const int64_t other) const { return v >= other; } - - i4& operator++() { + constexpr operator std::optional() const { + return static_cast(v); + } + + constexpr i4 operator-() const { return i4(-v); } + constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); } + constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); } + constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); } + constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); } + constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); } + + constexpr i4 operator&(const i4& other) const { return i4((v & other.v)); } + constexpr i4 operator|(const i4& other) const { return i4((v | other.v)); } + constexpr i4 operator^(const i4& other) const { return i4((v ^ other.v)); } + constexpr i4 operator~() const { return i4(~v); } + constexpr i4 operator>>(int amount) const { return i4((v >> amount)); } + constexpr i4 operator<<(int amount) const { return i4((v << amount)); } + + constexpr bool operator==(const i4& other) const { return v == other.v; } + constexpr bool operator!=(const i4& other) const { return v != other.v; } + constexpr bool operator<(const i4& other) const { return v < other.v; } + constexpr bool operator>(const i4& other) const { return v > other.v; } + constexpr bool operator<=(const i4& other) const { return v <= other.v; } + constexpr bool operator>=(const i4& other) const { return v >= other.v; } + + constexpr bool operator==(int64_t other) const { return v == other; } + constexpr bool operator!=(int64_t other) const { return v != other; } + constexpr bool operator<(int64_t other) const { return v < other; } + constexpr bool operator>(int64_t other) const { return v > other; } + constexpr bool operator<=(int64_t other) const { return v <= other; } + constexpr bool operator>=(int64_t other) const { return v >= other; } + + friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v; } + friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v; } + friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v; } + friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v; } + friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v; } + friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v; } + + constexpr i4& operator++() { v = (v + 1) & 0x0F; return *this; } + constexpr i4 operator++(int) { + i4 orig = *this; + this->operator++(); + return orig; + } + + constexpr i4& operator--() { + v = (v - 1) & 0x0F; + return *this; + } + + constexpr i4 operator--(int) { + i4 orig = *this; + this->operator--(); + return orig; + } + + constexpr i4& operator+=(const i4& other) { + *this = *this + other; + return *this; + } + constexpr i4& operator-=(const i4& other) { + *this = *this - other; + return *this; + } + constexpr i4& operator*=(const i4& other) { + *this = *this * other; + return *this; + } + constexpr i4& operator/=(const i4& other) { + *this = *this / other; + return *this; + } + constexpr i4& operator%=(const i4& other) { + *this = *this % other; + return *this; + } + constexpr i4& operator&=(const i4& other) { + *this = *this & other; + return *this; + } + constexpr i4& operator|=(const i4& other) { + *this = *this | other; + return *this; + } + constexpr i4& operator^=(const i4& other) { + *this = *this ^ other; + return *this; + } + constexpr i4& operator>>=(int amount) { + *this = *this >> amount; + return *this; + } + constexpr i4& operator<<=(int amount) { + *this = *this << amount; + return *this; + } + friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { os << static_cast(num.v); return os; @@ -94,6 +170,93 @@ struct i4 { using int4 = i4; using uint4 = i4; +namespace internal { + +struct int4_numeric_limits_base { + static inline constexpr const bool is_specialized = true; + static inline constexpr const bool is_integer = true; + static inline constexpr const bool is_exact = true; + static inline constexpr const bool has_infinity = false; + static inline constexpr const bool has_quiet_NaN = false; + static inline constexpr const bool has_signaling_NaN = false; + static inline constexpr const std::float_denorm_style has_denorm = + std::denorm_absent; + static inline constexpr const bool has_denorm_loss = false; + static inline constexpr const std::float_round_style round_style = + std::round_toward_zero; + static inline constexpr const bool is_iec559 = false; + static inline constexpr const bool is_bounded = true; + static inline constexpr const int max_digits10 = 0; // Not used for integers. + static inline constexpr const int radix = 2; + static inline constexpr const int min_exponent = 0; + static inline constexpr const int min_exponent10 = 0; + static inline constexpr const int max_exponent = 0; + static inline constexpr const int max_exponent10 = 0; + static inline constexpr const bool traps = true; + static inline constexpr const bool tinyness_before = false; + + static constexpr ml_dtypes::int4 epsilon() noexcept { + return ml_dtypes::int4(0); + } + static constexpr ml_dtypes::int4 round_error() noexcept { + return ml_dtypes::int4(0); + } + static constexpr ml_dtypes::int4 infinity() noexcept { + return ml_dtypes::int4(0); + } + static constexpr ml_dtypes::int4 quiet_NaN() noexcept { + return ml_dtypes::int4(0); + } + static constexpr ml_dtypes::int4 signaling_NaN() noexcept { + return ml_dtypes::int4(0); + } + static constexpr ml_dtypes::int4 denorm_min() noexcept { + return ml_dtypes::int4(0); + } +}; + +} // namespace internal + } // namespace ml_dtypes +namespace std { + +template <> +struct numeric_limits + : public ml_dtypes::internal::int4_numeric_limits_base { + static inline constexpr const bool is_signed = true; + static inline constexpr const bool is_modulo = false; + static inline constexpr const int digits = 3; + static inline constexpr const int digits10 = 0; // floor(3 * log10(2)) + static constexpr ml_dtypes::int4 min() noexcept { + return ml_dtypes::int4::lowest(); + } + static constexpr ml_dtypes::int4 lowest() noexcept { + return ml_dtypes::int4::lowest(); + } + static constexpr ml_dtypes::int4 max() noexcept { + return ml_dtypes::int4::highest(); + } +}; + +template <> +struct numeric_limits + : public ml_dtypes::internal::int4_numeric_limits_base { + static inline constexpr const bool is_signed = false; + static inline constexpr const bool is_modulo = true; + static inline constexpr const int digits = 4; + static inline constexpr const int digits10 = 1; // floor(4 * log10(2)) + static constexpr ml_dtypes::uint4 min() noexcept { + return ml_dtypes::uint4::lowest(); + } + static constexpr ml_dtypes::uint4 lowest() noexcept { + return ml_dtypes::uint4::lowest(); + } + static constexpr ml_dtypes::uint4 max() noexcept { + return ml_dtypes::uint4::highest(); + } +}; + +} // namespace std + #endif // ML_DTYPES_INT4_H_ diff --git a/ml_dtypes/tests/int4_test.cc b/ml_dtypes/tests/int4_test.cc new file mode 100644 index 00000000..fc0304d5 --- /dev/null +++ b/ml_dtypes/tests/int4_test.cc @@ -0,0 +1,397 @@ +/* Copyright 2023 The ml_dtypes Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "include/int4.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "Eigen/Core" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace ml_dtypes { +namespace { + +template +class Int4Test : public ::testing::Test {}; + +// Helper utility for prettier test names. +struct Int4TestParamNames { + template + static std::string GetName(int idx) { + if constexpr (std::is_same_v) { + return "int4"; + } else if constexpr (std::is_same_v) { + return "uint4"; + } + return std::to_string(idx); + } +}; + +using Int4Types = ::testing::Types; +TYPED_TEST_SUITE(Int4Test, Int4Types, Int4TestParamNames); + +TEST(Int4Test, NumericLimits) { + EXPECT_EQ(std::numeric_limits::is_signed, true); + EXPECT_EQ(std::numeric_limits::is_modulo, false); + EXPECT_EQ(static_cast(std::numeric_limits::min()), -8); + EXPECT_EQ(static_cast(std::numeric_limits::max()), 7); + EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -8); + EXPECT_EQ(std::numeric_limits::digits, 3); + EXPECT_EQ(std::numeric_limits::digits10, 0); +} + +TEST(UInt4Test, NumericLimits) { + EXPECT_EQ(std::numeric_limits::is_signed, false); + EXPECT_EQ(std::numeric_limits::is_modulo, true); + EXPECT_EQ(static_cast(std::numeric_limits::min()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::max()), 15); + EXPECT_EQ(static_cast(std::numeric_limits::lowest()), 0); + EXPECT_EQ(std::numeric_limits::digits, 4); + EXPECT_EQ(std::numeric_limits::digits10, 1); +} + +TYPED_TEST(Int4Test, NumericLimitsBase) { + using Int4 = TypeParam; + EXPECT_EQ(std::numeric_limits::is_specialized, true); + EXPECT_EQ(std::numeric_limits::is_integer, true); + EXPECT_EQ(std::numeric_limits::is_exact, true); + EXPECT_EQ(std::numeric_limits::has_infinity, false); + EXPECT_EQ(std::numeric_limits::has_quiet_NaN, false); + EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); + EXPECT_EQ(std::numeric_limits::has_denorm, std::denorm_absent); + EXPECT_EQ(std::numeric_limits::has_denorm_loss, false); + EXPECT_EQ(std::numeric_limits::round_style, std::round_toward_zero); + EXPECT_EQ(std::numeric_limits::is_iec559, false); + EXPECT_EQ(std::numeric_limits::is_bounded, true); + EXPECT_EQ(std::numeric_limits::radix, 2); + EXPECT_EQ(std::numeric_limits::min_exponent, 0); + EXPECT_EQ(std::numeric_limits::min_exponent10, 0); + EXPECT_EQ(std::numeric_limits::max_exponent, 0); + EXPECT_EQ(std::numeric_limits::max_exponent10, 0); + EXPECT_EQ(std::numeric_limits::traps, true); + EXPECT_EQ(std::numeric_limits::tinyness_before, false); + EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::infinity()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::quiet_NaN()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::signaling_NaN()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), 0); +} + +TYPED_TEST(Int4Test, CreateAndAssign) { + using Int4 = TypeParam; + + // Constructors. + EXPECT_EQ(Int4(), Int4(0)); + Int4 a(1); + EXPECT_EQ(a, Int4(1)); + Int4 b(std::move(a)); + EXPECT_EQ(b, Int4(1)); + + // Assignments. + EXPECT_EQ(a = Int4(2), Int4(2)); + EXPECT_EQ(b = a, Int4(2)); + EXPECT_EQ((a = Int4(3), b = std::move(a)), Int4(3)); +} + +// To ensure an expression is evaluated in a constexpr context, +// we use the trick of inserting the expression in a template +// parameter. +template +struct ConstexprEvaluator { + static constexpr bool val = true; +}; + +// To avoid warnings about unused left-side of comma expressions, +// we additionally pass the expression through a contexpr function. +template +constexpr void ConstexprEvaluatorFunc(T&&){} + +#define TEST_CONSTEXPR(expr) \ + do { \ + EXPECT_TRUE((ConstexprEvaluator<(ConstexprEvaluatorFunc(expr), 1)>::val)); \ + } while (false) + +TYPED_TEST(Int4Test, Constexpr) { + TEST_CONSTEXPR(int4(0)); + TEST_CONSTEXPR(static_cast(int4(0))); + + TEST_CONSTEXPR(-int4(1)); + TEST_CONSTEXPR(int4(0) + int4(1)); + TEST_CONSTEXPR(int4(1) - int4(0)); + TEST_CONSTEXPR(int4(0) * int4(1)); + TEST_CONSTEXPR(int4(0) / int4(1)); + TEST_CONSTEXPR(int4(0) % int4(1)); + + TEST_CONSTEXPR(int4(1) & int4(0xF)); + TEST_CONSTEXPR(int4(1) | int4(0xF)); + TEST_CONSTEXPR(int4(1) ^ int4(0xF)); + TEST_CONSTEXPR(~int4(1)); + TEST_CONSTEXPR(int4(1) >> 1); + TEST_CONSTEXPR(int4(1) << 1); + + TEST_CONSTEXPR(int4(1) == int4(1)); + TEST_CONSTEXPR(int4(1) != int4(1)); + TEST_CONSTEXPR(int4(1) < int4(1)); + TEST_CONSTEXPR(int4(1) > int4(1)); + TEST_CONSTEXPR(int4(1) <= int4(1)); + TEST_CONSTEXPR(int4(1) >= int4(1)); + + TEST_CONSTEXPR(++int4(1)); + TEST_CONSTEXPR(int4(1)++); + TEST_CONSTEXPR(--int4(1)); + TEST_CONSTEXPR(int4(1)--); + + TEST_CONSTEXPR(int4(1) += int4(2)); + TEST_CONSTEXPR(int4(1) -= int4(2)); + TEST_CONSTEXPR(int4(1) *= int4(2)); + TEST_CONSTEXPR(int4(1) /= int4(2)); + TEST_CONSTEXPR(int4(1) %= int4(2)); + TEST_CONSTEXPR(int4(1) &= int4(2)); + TEST_CONSTEXPR(int4(1) |= int4(2)); + TEST_CONSTEXPR(int4(1) ^= int4(2)); + TEST_CONSTEXPR(int4(1) >>= 1); + TEST_CONSTEXPR(int4(1) <<= 1); +} + +TYPED_TEST(Int4Test, Casts) { + using Int4 = TypeParam; + + // Explicit integer types. + EXPECT_EQ(static_cast(Int4(4)), 4); + EXPECT_EQ(static_cast(Int4(5)), 5); + EXPECT_EQ(static_cast(Int4(6)), 6); + EXPECT_EQ(static_cast(Int4(7)), 7); + EXPECT_EQ(static_cast(Int4(1)), 1); + + // Implicit conversion to optional. + std::optional c = Int4(2); + EXPECT_EQ(c, 2); + + // Loop through all valid values. + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + // Round-trip. + EXPECT_EQ(static_cast(Int4(i)), i); + + // Float truncation. + for (int j = 1; j < 10; ++j) { + float offset = -1.f + j * 1.f / 5; + float f = i + offset; + EXPECT_EQ(Int4(f), Int4(static_cast(f))); + } + } +} + +TYPED_TEST(Int4Test, Operators) { + using Int4 = TypeParam; + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + Int4 x = Int4(i); + + EXPECT_EQ(-x, Int4(-i)); + EXPECT_EQ(~x, Int4(~i)); + Int4 a; + EXPECT_EQ((a = x, ++a), Int4(i + 1)); + EXPECT_EQ(a, Int4(i + 1)); + EXPECT_EQ((a = x, a++), Int4(i)); + EXPECT_EQ(a, Int4(i + 1)); + EXPECT_EQ((a = x, --a), Int4(i - 1)); + EXPECT_EQ(a, Int4(i - 1)); + EXPECT_EQ((a = x, a--), Int4(i)); + EXPECT_EQ(a, Int4(i - 1)); + + for (int j = static_cast(std::numeric_limits::min()); + j <= static_cast(std::numeric_limits::max()); ++j) { + Int4 y = Int4(j); + + EXPECT_EQ(x + y, Int4(i + j)); + EXPECT_EQ(x - y, Int4(i - j)); + EXPECT_EQ(x * y, Int4(i * j)); + if (j != 0) { + EXPECT_EQ(x / y, Int4(i / j)); + EXPECT_EQ(x % y, Int4(i % j)); + } + EXPECT_EQ(x & y, Int4(i & j)); + EXPECT_EQ(x | y, Int4(i | j)); + EXPECT_EQ(x ^ y, Int4(i ^ j)); + + EXPECT_EQ(x == y, i == j); + EXPECT_EQ(x != y, i != j); + EXPECT_EQ(x < y, i < j); + EXPECT_EQ(x > y, i > j); + EXPECT_EQ(x <= y, i <= j); + EXPECT_EQ(x >= y, i >= j); + + EXPECT_EQ(x == static_cast(j), i == j); + EXPECT_EQ(x != static_cast(j), i != j); + EXPECT_EQ(x < static_cast(j), i < j); + EXPECT_EQ(x > static_cast(j), i > j); + EXPECT_EQ(x <= static_cast(j), i <= j); + EXPECT_EQ(x >= static_cast(j), i >= j); + + EXPECT_EQ(static_cast(j) == x, j == i); + EXPECT_EQ(static_cast(j) != x, j != i); + EXPECT_EQ(static_cast(j) < x, j < i); + EXPECT_EQ(static_cast(j) > x, j > i); + EXPECT_EQ(static_cast(j) <= x, j <= i); + EXPECT_EQ(static_cast(j) >= x, j >= i); + + EXPECT_EQ((a = x, a += y), Int4(i + j)); + EXPECT_EQ((a = x, a -= y), Int4(i - j)); + EXPECT_EQ((a = x, a *= y), Int4(i * j)); + if (j != 0) { + EXPECT_EQ((a = x, a /= y), Int4(i / j)); + EXPECT_EQ((a = x, a %= y), Int4(i % j)); + } + EXPECT_EQ((a = x, a &= y), Int4(i & j)); + EXPECT_EQ((a = x, a |= y), Int4(i | j)); + EXPECT_EQ((a = x, a ^= y), Int4(i ^ j)); + } + + for (int amount = 0; amount < 4; ++amount) { + EXPECT_EQ(x >> amount, Int4(i >> amount)); + EXPECT_EQ(x << amount, Int4(i << amount)); + EXPECT_EQ((a = x, a >>= amount), Int4(i >> amount)); + EXPECT_EQ((a = x, a <<= amount), Int4(i << amount)); + } + } +} + +TYPED_TEST(Int4Test, ToString) { + using Int4 = TypeParam; + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + Int4 x = Int4(i); + std::stringstream ss; + ss << x; + EXPECT_EQ(ss.str(), std::to_string(i)); + EXPECT_EQ(x.ToString(), std::to_string(i)); + } +} + +#define GEN_DEST_TYPES(Type) \ + std::pair, std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair + +#define GEN_TYPE_PAIRS() GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4) + +using Int4CastTypePairs = ::testing::Types; +template +class Int4CastTest : public ::testing::Test {}; + +// Helper utility for prettier test names. +struct Int4CastTestParamNames { + template + static std::string GetName(int idx) { + using first_type = typename TypeParam::first_type; + using second_type = typename TypeParam::second_type; + return ::testing::internal::GetTypeName() + "_" + + ::testing::internal::GetTypeName(); + } +}; + +TYPED_TEST_SUITE(Int4CastTest, Int4CastTypePairs, Int4CastTestParamNames); + +TYPED_TEST(Int4CastTest, CastThroughInt) { + using Int4 = typename TypeParam::first_type; + using DestType = typename TypeParam::second_type; + + for (int i = 0x00; i <= 0x0F; ++i) { + Int4 x = Int4(i); + DestType dest = static_cast(x); + DestType expected = static_cast(static_cast(x)); + EXPECT_EQ(dest, expected); + } +} + +TYPED_TEST(Int4CastTest, DeviceCast) { + using Int4 = typename TypeParam::first_type; + using DestType = typename TypeParam::second_type; + +#if defined(EIGEN_USE_GPU) + Eigen::GpuStreamDevice stream; + Eigen::GpuDevice device(&stream); +#elif defined(EIGEN_USE_THREADS) + constexpr int kThreads = 4; + Eigen::ThreadPool tp(kThreads); + Eigen::ThreadPoolDevice device(&tp, kThreads); +#else + Eigen::DefaultDevice device; +#endif + + const int kNumElems = 256; + // Allocate device buffers and create device tensors. + Int4* src_device_buffer = (Int4*)device.allocate(kNumElems * sizeof(Int4)); + DestType* dst_device_buffer = + (DestType*)device.allocate(kNumElems * sizeof(DestType)); + + Eigen::TensorMap, Eigen::Aligned> src_device( + src_device_buffer, kNumElems); + Eigen::TensorMap, Eigen::Aligned> dst_device( + dst_device_buffer, kNumElems); + + // Allocate host buffers and initially src memory. + Eigen::Tensor src_cpu(kNumElems); + Eigen::Tensor dst_cpu(kNumElems); + for (int i = 0; i < kNumElems; ++i) { + src_cpu(i) = Eigen::numext::bit_cast(static_cast(i)); + } + + // Transfer data to device, perform a cast to DestType, then transfer result + // back to host. + device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), + kNumElems * sizeof(Int4)); + dst_device.device(device) = src_device.template cast(); + device.memcpyDeviceToHost(dst_cpu.data(), dst_device_buffer, + kNumElems * sizeof(DestType)); + device.synchronize(); + + for (int i = 0; i < kNumElems; ++i) { + DestType expected = static_cast(src_cpu(i)); + EXPECT_EQ(dst_cpu(i), expected); + } + + // Cast back from DestType to Int4. + // First clear out the device src buffer, since that will be the destination. + src_cpu.setZero(); + device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), + kNumElems * sizeof(Int4)); + src_device.device(device) = dst_device.template cast(); + device.memcpyDeviceToHost(src_cpu.data(), src_device_buffer, + kNumElems * sizeof(Int4)); + device.synchronize(); + + for (int i = 0; i < kNumElems; ++i) { + Int4 expected = static_cast(dst_cpu(i)); + EXPECT_EQ(src_cpu(i), expected); + } + + // Clean up. + device.deallocate(src_device_buffer); + device.deallocate(dst_device_buffer); + device.synchronize(); +} + +} // namespace +} // namespace ml_dtypes