diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index f22cfb12..6ee1e4e6 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -29,17 +29,16 @@ from typing import Type -from ml_dtypes._custom_floats import bfloat16 -from ml_dtypes._custom_floats import float8_e4m3b11fnuz -from ml_dtypes._custom_floats import float8_e4m3fn -from ml_dtypes._custom_floats import float8_e4m3fnuz -from ml_dtypes._custom_floats import float8_e5m2 -from ml_dtypes._custom_floats import float8_e5m2fnuz -from ml_dtypes._custom_floats import int4 -from ml_dtypes._custom_floats import uint4 from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo - +from ml_dtypes._ml_dtypes_lib import bfloat16 +from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz +from ml_dtypes._ml_dtypes_lib import float8_e4m3fn +from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz +from ml_dtypes._ml_dtypes_lib import float8_e5m2 +from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_lib import int4 +from ml_dtypes._ml_dtypes_lib import uint4 import numpy as np bfloat16: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 1a78bd63..87c8ef7e 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -16,13 +16,12 @@ from typing import Dict -from ml_dtypes._custom_floats import bfloat16 -from ml_dtypes._custom_floats import float8_e4m3b11fnuz -from ml_dtypes._custom_floats import float8_e4m3fn -from ml_dtypes._custom_floats import float8_e4m3fnuz -from ml_dtypes._custom_floats import float8_e5m2 -from ml_dtypes._custom_floats import float8_e5m2fnuz - +from ml_dtypes._ml_dtypes_lib import bfloat16 +from ml_dtypes._ml_dtypes_lib import float8_e4m3b11fnuz +from ml_dtypes._ml_dtypes_lib import float8_e4m3fn +from ml_dtypes._ml_dtypes_lib import float8_e4m3fnuz +from ml_dtypes._ml_dtypes_lib import float8_e5m2 +from ml_dtypes._ml_dtypes_lib import float8_e5m2fnuz import numpy as np _bfloat16_dtype = np.dtype(bfloat16) diff --git a/ml_dtypes/_iinfo.py b/ml_dtypes/_iinfo.py index 1c27412b..0854ed4e 100644 --- a/ml_dtypes/_iinfo.py +++ b/ml_dtypes/_iinfo.py @@ -14,9 +14,8 @@ """Overload of numpy.iinfo to handle dtypes defined in ml_dtypes.""" -from ml_dtypes._custom_floats import int4 -from ml_dtypes._custom_floats import uint4 - +from ml_dtypes._ml_dtypes_lib import int4 +from ml_dtypes._ml_dtypes_lib import uint4 import numpy as np _int4_dtype = np.dtype(int4) diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index c0d30008..f735d42e 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -31,8 +31,9 @@ limitations under the License. #include "Eigen/Core" #include "_src/custom_float.h" -#include "_src/int4.h" +#include "_src/int4_numpy.h" #include "include/float8.h" +#include "include/int4.h" namespace ml_dtypes { @@ -297,7 +298,7 @@ bool Initialize() { static PyModuleDef module_def = { PyModuleDef_HEAD_INIT, - "_custom_floats", + "_ml_dtypes_lib", }; // TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+. @@ -308,14 +309,14 @@ static PyModuleDef module_def = { #define EXPORT_SYMBOL __attribute__((visibility("default"))) #endif -extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() { +extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_lib() { Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def)); if (!m) { return nullptr; } if (!Initialize()) { if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module."); + PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_lib module."); } return nullptr; } diff --git a/ml_dtypes/_src/int4.h b/ml_dtypes/_src/int4_numpy.h similarity index 90% rename from ml_dtypes/_src/int4.h rename to ml_dtypes/_src/int4_numpy.h index ce855eb5..7f23fbc1 100644 --- a/ml_dtypes/_src/int4.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -13,95 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef ML_DTYPES_INT4_H_ -#define ML_DTYPES_INT4_H_ +#ifndef ML_DTYPES_INT4_NUMPY_H_ +#define ML_DTYPES_INT4_NUMPY_H_ // Must be included first // clang-format off #include "_src/numpy.h" // clang-format on -#include //NOLINT -#include //NOLINT -#include //NOLINT -#include //NOLINT - #include "Eigen/Core" #include "_src/common.h" // NOLINT #include "_src/ufuncs.h" // NOLINT +#include "include/int4.h" namespace ml_dtypes { -template -struct i4 { - private: - UnderlyingTy v : 4; - - public: - i4() : v(0) {} - explicit i4(UnderlyingTy val) : v(val & 0x0F) {} - template - explicit i4(T t) : i4(static_cast(t)) {} - i4(const i4& other) = default; - - static constexpr i4 lowest() { - return std::is_signed::value ? i4(-8) : i4(0); - } - static constexpr i4 highest() { - return std::is_signed::value ? i4(7) : i4(15); - } - - template >> - explicit operator T() const { - return static_cast(v); - } - // NOLINTNEXTLINE(google-explicit-constructor) - operator std::optional() const { return static_cast(v); } - - i4 operator-() const { return i4(-v); } - i4 operator+(const i4& other) const { return i4((v + other.v)); } - i4 operator-(const i4& other) const { return i4((v - other.v)); } - i4 operator*(const i4& other) const { return i4((v * other.v)); } - i4 operator/(const i4& other) const { return i4((v / other.v)); } - i4 operator%(const i4& other) const { return i4((v % other.v)); } - - i4 operator>>(const int amount) const { return i4((v >> amount)); } - i4 operator<<(const int amount) const { return i4((v << amount)); } - - bool operator==(const i4& other) const { return v == other.v; } - bool operator!=(const i4& other) const { return v != other.v; } - bool operator<(const i4& other) const { return v < other.v; } - bool operator>(const i4& other) const { return v > other.v; } - bool operator<=(const i4& other) const { return v <= other.v; } - bool operator>=(const i4& other) const { return v >= other.v; } - - bool operator==(const int64_t other) const { return v == other; } - bool operator!=(const int64_t other) const { return v != other; } - bool operator<(const int64_t other) const { return v < other; } - bool operator>(const int64_t other) const { return v > other; } - bool operator<=(const int64_t other) const { return v <= other; } - bool operator>=(const int64_t other) const { return v >= other; } - - i4& operator++() { - v = (v + 1) & 0x0F; - return *this; - } - - friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { - os << static_cast(num.v); - return os; - } - - std::string ToString() const { - std::ostringstream os; - os << static_cast(v); - return os.str(); - } -}; - -using int4 = i4; -using uint4 = i4; - template struct Int4TypeDescriptor { static int Dtype() { return npy_type; } @@ -878,4 +804,4 @@ bool RegisterInt4Dtype(PyObject* numpy) { } // namespace ml_dtypes -#endif // ML_DTYPES_INT4_H_ +#endif // ML_DTYPES_INT4_NUMPY_H_ diff --git a/ml_dtypes/include/int4.h b/ml_dtypes/include/int4.h new file mode 100644 index 00000000..b313940f --- /dev/null +++ b/ml_dtypes/include/int4.h @@ -0,0 +1,100 @@ +/* Copyright 2023 The ml_dtypes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef ML_DTYPES_INT4_H_ +#define ML_DTYPES_INT4_H_ + +#include +#include +#include +#include +#include +#include + +namespace ml_dtypes { + +template +struct i4 { + private: + UnderlyingTy v : 4; + + public: + i4() : v(0) {} + explicit i4(UnderlyingTy val) : v(val & 0x0F) {} + template + explicit i4(T t) : i4(static_cast(t)) {} + i4(const i4& other) = default; + + static constexpr i4 lowest() { + return std::is_signed::value ? i4(-8) : i4(0); + } + static constexpr i4 highest() { + return std::is_signed::value ? i4(7) : i4(15); + } + + template >> + explicit operator T() const { + return static_cast(v); + } + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::optional() const { return static_cast(v); } + + i4 operator-() const { return i4(-v); } + i4 operator+(const i4& other) const { return i4((v + other.v)); } + i4 operator-(const i4& other) const { return i4((v - other.v)); } + i4 operator*(const i4& other) const { return i4((v * other.v)); } + i4 operator/(const i4& other) const { return i4((v / other.v)); } + i4 operator%(const i4& other) const { return i4((v % other.v)); } + + i4 operator>>(const int amount) const { return i4((v >> amount)); } + i4 operator<<(const int amount) const { return i4((v << amount)); } + + bool operator==(const i4& other) const { return v == other.v; } + bool operator!=(const i4& other) const { return v != other.v; } + bool operator<(const i4& other) const { return v < other.v; } + bool operator>(const i4& other) const { return v > other.v; } + bool operator<=(const i4& other) const { return v <= other.v; } + bool operator>=(const i4& other) const { return v >= other.v; } + + bool operator==(const int64_t other) const { return v == other; } + bool operator!=(const int64_t other) const { return v != other; } + bool operator<(const int64_t other) const { return v < other; } + bool operator>(const int64_t other) const { return v > other; } + bool operator<=(const int64_t other) const { return v <= other; } + bool operator>=(const int64_t other) const { return v >= other; } + + i4& operator++() { + v = (v + 1) & 0x0F; + return *this; + } + + friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { + os << static_cast(num.v); + return os; + } + + std::string ToString() const { + std::ostringstream os; + os << static_cast(v); + return os.str(); + } +}; + +using int4 = i4; +using uint4 = i4; + +} // namespace ml_dtypes + +#endif // ML_DTYPES_INT4_H_ diff --git a/ml_dtypes/tests/int4_test.cc b/ml_dtypes/tests/int4_test.cc new file mode 100644 index 00000000..53d35157 --- /dev/null +++ b/ml_dtypes/tests/int4_test.cc @@ -0,0 +1,397 @@ +/* Copyright 2023 The ml_dtypes Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "include/int4.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "Eigen/Core" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace ml_dtypes { +namespace { + +template +class Int4Test : public ::testing::Test {}; + +// Helper utility for prettier test names. +struct Int4TestParamNames { + template + static std::string GetName(int idx) { + if constexpr (std::is_same_v) { + return "int4"; + } else if constexpr (std::is_same_v) { + return "uint4"; + } + return std::to_string(idx); + } +}; + +using Int4Types = ::testing::Types; +TYPED_TEST_SUITE(Int4Test, Int4Types, Int4TestParamNames); + +TEST(Int4Test, NumericLimits) { + EXPECT_EQ(std::numeric_limits::is_signed, true); + EXPECT_EQ(std::numeric_limits::is_modulo, false); + EXPECT_EQ(static_cast(std::numeric_limits::min()), -8); + EXPECT_EQ(static_cast(std::numeric_limits::max()), 7); + EXPECT_EQ(static_cast(std::numeric_limits::lowest()), -8); + EXPECT_EQ(std::numeric_limits::digits, 3); + EXPECT_EQ(std::numeric_limits::digits10, 0); +} + +TEST(UInt4Test, NumericLimits) { + EXPECT_EQ(std::numeric_limits::is_signed, false); + EXPECT_EQ(std::numeric_limits::is_modulo, true); + EXPECT_EQ(static_cast(std::numeric_limits::min()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::max()), 15); + EXPECT_EQ(static_cast(std::numeric_limits::lowest()), 0); + EXPECT_EQ(std::numeric_limits::digits, 4); + EXPECT_EQ(std::numeric_limits::digits10, 1); +} + +TYPED_TEST(Int4Test, NumericLimitsBase) { + using Int4 = TypeParam; + EXPECT_EQ(std::numeric_limits::is_specialized, true); + EXPECT_EQ(std::numeric_limits::is_integer, true); + EXPECT_EQ(std::numeric_limits::is_exact, true); + EXPECT_EQ(std::numeric_limits::has_infinity, false); + EXPECT_EQ(std::numeric_limits::has_quiet_NaN, false); + EXPECT_EQ(std::numeric_limits::has_signaling_NaN, false); + EXPECT_EQ(std::numeric_limits::has_denorm, std::denorm_absent); + EXPECT_EQ(std::numeric_limits::has_denorm_loss, false); + EXPECT_EQ(std::numeric_limits::round_style, std::round_toward_zero); + EXPECT_EQ(std::numeric_limits::is_iec559, false); + EXPECT_EQ(std::numeric_limits::is_bounded, true); + EXPECT_EQ(std::numeric_limits::radix, 2); + EXPECT_EQ(std::numeric_limits::min_exponent, 0); + EXPECT_EQ(std::numeric_limits::min_exponent10, 0); + EXPECT_EQ(std::numeric_limits::max_exponent, 0); + EXPECT_EQ(std::numeric_limits::max_exponent10, 0); + EXPECT_EQ(std::numeric_limits::traps, true); + EXPECT_EQ(std::numeric_limits::tinyness_before, false); + EXPECT_EQ(static_cast(std::numeric_limits::epsilon()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::round_error()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::infinity()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::quiet_NaN()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::signaling_NaN()), 0); + EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), 0); +} + +TYPED_TEST(Int4Test, CreateAndAssign) { + using Int4 = TypeParam; + + // Constructors. + EXPECT_EQ(Int4(), Int4(0)); + Int4 a(1); + EXPECT_EQ(a, Int4(1)); + Int4 b(std::move(a)); + EXPECT_EQ(b, Int4(1)); + + // Assignments. + EXPECT_EQ(a = Int4(2), Int4(2)); + EXPECT_EQ(b = a, Int4(2)); + EXPECT_EQ((a = Int4(3), b = std::move(a)), Int4(3)); +} + +// To ensure an expression is evaluated in a constexpr context, +// we use the trick of inserting the expression in a template +// parameter. +template +struct ConstexprEvaluator { + static constexpr bool val = true; +}; + +// To avoid warnings about unused left-side of comma expressions, +// we additionally pass the expression through a contexpr function. +template +constexpr void ConstexprEvaluatorFunc(T&&){}; + +#define TEST_CONSTEXPR(expr) \ + do { \ + EXPECT_TRUE((ConstexprEvaluator<(ConstexprEvaluatorFunc(expr), 1)>::val)); \ + } while (false) + +TYPED_TEST(Int4Test, Constexpr) { + TEST_CONSTEXPR(int4(0)); + TEST_CONSTEXPR(static_cast(int4(0))); + + TEST_CONSTEXPR(-int4(1)); + TEST_CONSTEXPR(int4(0) + int4(1)); + TEST_CONSTEXPR(int4(1) - int4(0)); + TEST_CONSTEXPR(int4(0) * int4(1)); + TEST_CONSTEXPR(int4(0) / int4(1)); + TEST_CONSTEXPR(int4(0) % int4(1)); + + TEST_CONSTEXPR(int4(1) & int4(0xF)); + TEST_CONSTEXPR(int4(1) | int4(0xF)); + TEST_CONSTEXPR(int4(1) ^ int4(0xF)); + TEST_CONSTEXPR(~int4(1)); + TEST_CONSTEXPR(int4(1) >> 1); + TEST_CONSTEXPR(int4(1) << 1); + + TEST_CONSTEXPR(int4(1) == int4(1)); + TEST_CONSTEXPR(int4(1) != int4(1)); + TEST_CONSTEXPR(int4(1) < int4(1)); + TEST_CONSTEXPR(int4(1) > int4(1)); + TEST_CONSTEXPR(int4(1) <= int4(1)); + TEST_CONSTEXPR(int4(1) >= int4(1)); + + TEST_CONSTEXPR(++int4(1)); + TEST_CONSTEXPR(int4(1)++); + TEST_CONSTEXPR(--int4(1)); + TEST_CONSTEXPR(int4(1)--); + + TEST_CONSTEXPR(int4(1) += int4(2)); + TEST_CONSTEXPR(int4(1) -= int4(2)); + TEST_CONSTEXPR(int4(1) *= int4(2)); + TEST_CONSTEXPR(int4(1) /= int4(2)); + TEST_CONSTEXPR(int4(1) %= int4(2)); + TEST_CONSTEXPR(int4(1) &= int4(2)); + TEST_CONSTEXPR(int4(1) |= int4(2)); + TEST_CONSTEXPR(int4(1) ^= int4(2)); + TEST_CONSTEXPR(int4(1) >>= 1); + TEST_CONSTEXPR(int4(1) <<= 1); +} + +TYPED_TEST(Int4Test, Casts) { + using Int4 = TypeParam; + + // Explicit integer types. + EXPECT_EQ(static_cast(Int4(4)), 4); + EXPECT_EQ(static_cast(Int4(5)), 5); + EXPECT_EQ(static_cast(Int4(6)), 6); + EXPECT_EQ(static_cast(Int4(7)), 7); + EXPECT_EQ(static_cast(Int4(1)), 1); + + // Implicit conversion to optional. + std::optional c = Int4(2); + EXPECT_EQ(c, 2); + + // Loop through all valid values. + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + // Round-trip. + EXPECT_EQ(static_cast(Int4(i)), i); + + // Float truncation. + for (int j = 1; j < 10; ++j) { + float offset = -1.f + j * 1.f / 5; + float f = i + offset; + EXPECT_EQ(Int4(f), Int4(static_cast(f))); + } + } +} + +TYPED_TEST(Int4Test, Operators) { + using Int4 = TypeParam; + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + Int4 x = Int4(i); + + EXPECT_EQ(-x, Int4(-i)); + EXPECT_EQ(~x, Int4(~i)); + Int4 a; + EXPECT_EQ((a = x, ++a), Int4(i + 1)); + EXPECT_EQ(a, Int4(i + 1)); + EXPECT_EQ((a = x, a++), Int4(i)); + EXPECT_EQ(a, Int4(i + 1)); + EXPECT_EQ((a = x, --a), Int4(i - 1)); + EXPECT_EQ(a, Int4(i - 1)); + EXPECT_EQ((a = x, a--), Int4(i)); + EXPECT_EQ(a, Int4(i - 1)); + + for (int j = static_cast(std::numeric_limits::min()); + j <= static_cast(std::numeric_limits::max()); ++j) { + Int4 y = Int4(j); + + EXPECT_EQ(x + y, Int4(i + j)); + EXPECT_EQ(x - y, Int4(i - j)); + EXPECT_EQ(x * y, Int4(i * j)); + if (j != 0) { + EXPECT_EQ(x / y, Int4(i / j)); + EXPECT_EQ(x % y, Int4(i % j)); + } + EXPECT_EQ(x & y, Int4(i & j)); + EXPECT_EQ(x | y, Int4(i | j)); + EXPECT_EQ(x ^ y, Int4(i ^ j)); + + EXPECT_EQ(x == y, i == j); + EXPECT_EQ(x != y, i != j); + EXPECT_EQ(x < y, i < j); + EXPECT_EQ(x > y, i > j); + EXPECT_EQ(x <= y, i <= j); + EXPECT_EQ(x >= y, i >= j); + + EXPECT_EQ(x == static_cast(j), i == j); + EXPECT_EQ(x != static_cast(j), i != j); + EXPECT_EQ(x < static_cast(j), i < j); + EXPECT_EQ(x > static_cast(j), i > j); + EXPECT_EQ(x <= static_cast(j), i <= j); + EXPECT_EQ(x >= static_cast(j), i >= j); + + EXPECT_EQ(static_cast(j) == x, j == i); + EXPECT_EQ(static_cast(j) != x, j != i); + EXPECT_EQ(static_cast(j) < x, j < i); + EXPECT_EQ(static_cast(j) > x, j > i); + EXPECT_EQ(static_cast(j) <= x, j <= i); + EXPECT_EQ(static_cast(j) >= x, j >= i); + + EXPECT_EQ((a = x, a += y), Int4(i + j)); + EXPECT_EQ((a = x, a -= y), Int4(i - j)); + EXPECT_EQ((a = x, a *= y), Int4(i * j)); + if (j != 0) { + EXPECT_EQ((a = x, a /= y), Int4(i / j)); + EXPECT_EQ((a = x, a %= y), Int4(i % j)); + } + EXPECT_EQ((a = x, a &= y), Int4(i & j)); + EXPECT_EQ((a = x, a |= y), Int4(i | j)); + EXPECT_EQ((a = x, a ^= y), Int4(i ^ j)); + } + + for (int amount = 0; amount < 4; ++amount) { + EXPECT_EQ(x >> amount, Int4(i >> amount)); + EXPECT_EQ(x << amount, Int4(i << amount)); + EXPECT_EQ((a = x, a >>= amount), Int4(i >> amount)); + EXPECT_EQ((a = x, a <<= amount), Int4(i << amount)); + } + } +} + +TYPED_TEST(Int4Test, ToString) { + using Int4 = TypeParam; + for (int i = static_cast(std::numeric_limits::min()); + i <= static_cast(std::numeric_limits::max()); ++i) { + Int4 x = Int4(i); + std::stringstream ss; + ss << x; + EXPECT_EQ(ss.str(), std::to_string(i)); + EXPECT_EQ(x.ToString(), std::to_string(i)); + } +} + +#define GEN_DEST_TYPES(Type) \ + std::pair, std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair + +#define GEN_TYPE_PAIRS() GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4) + +using Int4CastTypePairs = ::testing::Types; +template +class Int4CastTest : public ::testing::Test {}; + +// Helper utility for prettier test names. +struct Int4CastTestParamNames { + template + static std::string GetName(int idx) { + using first_type = typename TypeParam::first_type; + using second_type = typename TypeParam::second_type; + return ::testing::internal::GetTypeName() + "_" + + ::testing::internal::GetTypeName(); + } +}; + +TYPED_TEST_SUITE(Int4CastTest, Int4CastTypePairs, Int4CastTestParamNames); + +TYPED_TEST(Int4CastTest, CastThroughInt) { + using Int4 = typename TypeParam::first_type; + using DestType = typename TypeParam::second_type; + + for (int i = 0x00; i <= 0x0F; ++i) { + Int4 x = Int4(i); + DestType dest = static_cast(x); + DestType expected = static_cast(static_cast(x)); + EXPECT_EQ(dest, expected); + } +} + +TYPED_TEST(Int4CastTest, DeviceCast) { + using Int4 = typename TypeParam::first_type; + using DestType = typename TypeParam::second_type; + +#if defined(EIGEN_USE_GPU) + Eigen::GpuStreamDevice stream; + Eigen::GpuDevice device(&stream); +#elif defined(EIGEN_USE_THREADS) + constexpr int kThreads = 4; + Eigen::ThreadPool tp(kThreads); + Eigen::ThreadPoolDevice device(&tp, kThreads); +#else + Eigen::DefaultDevice device; +#endif + + const int kNumElems = 256; + // Allocate device buffers and create device tensors. + Int4* src_device_buffer = (Int4*)device.allocate(kNumElems * sizeof(Int4)); + DestType* dst_device_buffer = + (DestType*)device.allocate(kNumElems * sizeof(DestType)); + + Eigen::TensorMap, Eigen::Aligned> src_device( + src_device_buffer, kNumElems); + Eigen::TensorMap, Eigen::Aligned> dst_device( + dst_device_buffer, kNumElems); + + // Allocate host buffers and initially src memory. + Eigen::Tensor src_cpu(kNumElems); + Eigen::Tensor dst_cpu(kNumElems); + for (int i = 0; i < kNumElems; ++i) { + src_cpu(i) = Eigen::numext::bit_cast(static_cast(i)); + } + + // Transfer data to device, perform a cast to DestType, then transfer result + // back to host. + device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), + kNumElems * sizeof(Int4)); + dst_device.device(device) = src_device.template cast(); + device.memcpyDeviceToHost(dst_cpu.data(), dst_device_buffer, + kNumElems * sizeof(DestType)); + device.synchronize(); + + for (int i = 0; i < kNumElems; ++i) { + DestType expected = static_cast(src_cpu(i)); + EXPECT_EQ(dst_cpu(i), expected); + } + + // Cast back from DestType to Int4. + // First clear out the device src buffer, since that will be the destination. + src_cpu.setZero(); + device.memcpyHostToDevice(src_device_buffer, src_cpu.data(), + kNumElems * sizeof(Int4)); + src_device.device(device) = dst_device.template cast(); + device.memcpyDeviceToHost(src_cpu.data(), src_device_buffer, + kNumElems * sizeof(Int4)); + device.synchronize(); + + for (int i = 0; i < kNumElems; ++i) { + Int4 expected = static_cast(dst_cpu(i)); + EXPECT_EQ(src_cpu(i), expected); + } + + // Clean up. + device.deallocate(src_device_buffer); + device.deallocate(dst_device_buffer); + device.synchronize(); +} + +} // namespace +} // namespace ml_dtypes \ No newline at end of file diff --git a/setup.py b/setup.py index 38b7a1d0..a0e9b205 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def find_package_modules(self, package, package_dir): setup( ext_modules=[ Pybind11Extension( - "ml_dtypes._custom_floats", + "ml_dtypes._ml_dtypes_lib", [ "ml_dtypes/_src/dtypes.cc", "ml_dtypes/_src/numpy.cc",