diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index f22cfb12..2c7d1aca 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -29,17 +29,16 @@ from typing import Type -from ml_dtypes._custom_floats import bfloat16 -from ml_dtypes._custom_floats import float8_e4m3b11fnuz -from ml_dtypes._custom_floats import float8_e4m3fn -from ml_dtypes._custom_floats import float8_e4m3fnuz -from ml_dtypes._custom_floats import float8_e5m2 -from ml_dtypes._custom_floats import float8_e5m2fnuz -from ml_dtypes._custom_floats import int4 -from ml_dtypes._custom_floats import uint4 from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo - +from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz +from ml_dtypes._ml_dtypes_ext import float8_e4m3fn +from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz +from ml_dtypes._ml_dtypes_ext import float8_e5m2 +from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_ext import int4 +from ml_dtypes._ml_dtypes_ext import uint4 import numpy as np bfloat16: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 1a78bd63..451f2766 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -16,13 +16,12 @@ from typing import Dict -from ml_dtypes._custom_floats import bfloat16 -from ml_dtypes._custom_floats import float8_e4m3b11fnuz -from ml_dtypes._custom_floats import float8_e4m3fn -from ml_dtypes._custom_floats import float8_e4m3fnuz -from ml_dtypes._custom_floats import float8_e5m2 -from ml_dtypes._custom_floats import float8_e5m2fnuz - +from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz +from ml_dtypes._ml_dtypes_ext import float8_e4m3fn +from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz +from ml_dtypes._ml_dtypes_ext import float8_e5m2 +from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz import numpy as np _bfloat16_dtype = np.dtype(bfloat16) diff --git a/ml_dtypes/_iinfo.py b/ml_dtypes/_iinfo.py index 1c27412b..eb9f0943 100644 --- a/ml_dtypes/_iinfo.py +++ b/ml_dtypes/_iinfo.py @@ -14,9 +14,8 @@ """Overload of numpy.iinfo to handle dtypes defined in ml_dtypes.""" -from ml_dtypes._custom_floats import int4 -from ml_dtypes._custom_floats import uint4 - +from ml_dtypes._ml_dtypes_ext import int4 +from ml_dtypes._ml_dtypes_ext import uint4 import numpy as np _int4_dtype = np.dtype(int4) diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index c0d30008..bfd9ca35 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -31,8 +31,9 @@ limitations under the License. #include "Eigen/Core" #include "_src/custom_float.h" -#include "_src/int4.h" +#include "_src/int4_numpy.h" #include "include/float8.h" +#include "include/int4.h" namespace ml_dtypes { @@ -297,7 +298,7 @@ bool Initialize() { static PyModuleDef module_def = { PyModuleDef_HEAD_INIT, - "_custom_floats", + "_ml_dtypes_ext", }; // TODO(phawkins): PyMODINIT_FUNC handles visibility correctly in Python 3.9+. @@ -308,14 +309,14 @@ static PyModuleDef module_def = { #define EXPORT_SYMBOL __attribute__((visibility("default"))) #endif -extern "C" EXPORT_SYMBOL PyObject* PyInit__custom_floats() { +extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def)); if (!m) { return nullptr; } if (!Initialize()) { if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "cannot load _custom_floats module."); + PyErr_SetString(PyExc_RuntimeError, "cannot load _ml_dtypes_ext module."); } return nullptr; } diff --git a/ml_dtypes/_src/int4.h b/ml_dtypes/_src/int4_numpy.h similarity index 90% rename from ml_dtypes/_src/int4.h rename to ml_dtypes/_src/int4_numpy.h index ce855eb5..7f23fbc1 100644 --- a/ml_dtypes/_src/int4.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -13,95 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef ML_DTYPES_INT4_H_ -#define ML_DTYPES_INT4_H_ +#ifndef ML_DTYPES_INT4_NUMPY_H_ +#define ML_DTYPES_INT4_NUMPY_H_ // Must be included first // clang-format off #include "_src/numpy.h" // clang-format on -#include //NOLINT -#include //NOLINT -#include //NOLINT -#include //NOLINT - #include "Eigen/Core" #include "_src/common.h" // NOLINT #include "_src/ufuncs.h" // NOLINT +#include "include/int4.h" namespace ml_dtypes { -template -struct i4 { - private: - UnderlyingTy v : 4; - - public: - i4() : v(0) {} - explicit i4(UnderlyingTy val) : v(val & 0x0F) {} - template - explicit i4(T t) : i4(static_cast(t)) {} - i4(const i4& other) = default; - - static constexpr i4 lowest() { - return std::is_signed::value ? i4(-8) : i4(0); - } - static constexpr i4 highest() { - return std::is_signed::value ? i4(7) : i4(15); - } - - template >> - explicit operator T() const { - return static_cast(v); - } - // NOLINTNEXTLINE(google-explicit-constructor) - operator std::optional() const { return static_cast(v); } - - i4 operator-() const { return i4(-v); } - i4 operator+(const i4& other) const { return i4((v + other.v)); } - i4 operator-(const i4& other) const { return i4((v - other.v)); } - i4 operator*(const i4& other) const { return i4((v * other.v)); } - i4 operator/(const i4& other) const { return i4((v / other.v)); } - i4 operator%(const i4& other) const { return i4((v % other.v)); } - - i4 operator>>(const int amount) const { return i4((v >> amount)); } - i4 operator<<(const int amount) const { return i4((v << amount)); } - - bool operator==(const i4& other) const { return v == other.v; } - bool operator!=(const i4& other) const { return v != other.v; } - bool operator<(const i4& other) const { return v < other.v; } - bool operator>(const i4& other) const { return v > other.v; } - bool operator<=(const i4& other) const { return v <= other.v; } - bool operator>=(const i4& other) const { return v >= other.v; } - - bool operator==(const int64_t other) const { return v == other; } - bool operator!=(const int64_t other) const { return v != other; } - bool operator<(const int64_t other) const { return v < other; } - bool operator>(const int64_t other) const { return v > other; } - bool operator<=(const int64_t other) const { return v <= other; } - bool operator>=(const int64_t other) const { return v >= other; } - - i4& operator++() { - v = (v + 1) & 0x0F; - return *this; - } - - friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { - os << static_cast(num.v); - return os; - } - - std::string ToString() const { - std::ostringstream os; - os << static_cast(v); - return os.str(); - } -}; - -using int4 = i4; -using uint4 = i4; - template struct Int4TypeDescriptor { static int Dtype() { return npy_type; } @@ -878,4 +804,4 @@ bool RegisterInt4Dtype(PyObject* numpy) { } // namespace ml_dtypes -#endif // ML_DTYPES_INT4_H_ +#endif // ML_DTYPES_INT4_NUMPY_H_ diff --git a/ml_dtypes/include/int4.h b/ml_dtypes/include/int4.h new file mode 100644 index 00000000..51b2c2d5 --- /dev/null +++ b/ml_dtypes/include/int4.h @@ -0,0 +1,99 @@ +/* Copyright 2023 The ml_dtypes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef ML_DTYPES_INT4_H_ +#define ML_DTYPES_INT4_H_ + +#include +#include +#include +#include +#include + +namespace ml_dtypes { + +template +struct i4 { + private: + UnderlyingTy v : 4; + + public: + i4() : v(0) {} + explicit i4(UnderlyingTy val) : v(val & 0x0F) {} + template + explicit i4(T t) : i4(static_cast(t)) {} + i4(const i4& other) = default; + + static constexpr i4 lowest() { + return std::is_signed::value ? i4(-8) : i4(0); + } + static constexpr i4 highest() { + return std::is_signed::value ? i4(7) : i4(15); + } + + template >> + explicit operator T() const { + return static_cast(v); + } + // NOLINTNEXTLINE(google-explicit-constructor) + operator std::optional() const { return static_cast(v); } + + i4 operator-() const { return i4(-v); } + i4 operator+(const i4& other) const { return i4((v + other.v)); } + i4 operator-(const i4& other) const { return i4((v - other.v)); } + i4 operator*(const i4& other) const { return i4((v * other.v)); } + i4 operator/(const i4& other) const { return i4((v / other.v)); } + i4 operator%(const i4& other) const { return i4((v % other.v)); } + + i4 operator>>(const int amount) const { return i4((v >> amount)); } + i4 operator<<(const int amount) const { return i4((v << amount)); } + + bool operator==(const i4& other) const { return v == other.v; } + bool operator!=(const i4& other) const { return v != other.v; } + bool operator<(const i4& other) const { return v < other.v; } + bool operator>(const i4& other) const { return v > other.v; } + bool operator<=(const i4& other) const { return v <= other.v; } + bool operator>=(const i4& other) const { return v >= other.v; } + + bool operator==(const int64_t other) const { return v == other; } + bool operator!=(const int64_t other) const { return v != other; } + bool operator<(const int64_t other) const { return v < other; } + bool operator>(const int64_t other) const { return v > other; } + bool operator<=(const int64_t other) const { return v <= other; } + bool operator>=(const int64_t other) const { return v >= other; } + + i4& operator++() { + v = (v + 1) & 0x0F; + return *this; + } + + friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { + os << static_cast(num.v); + return os; + } + + std::string ToString() const { + std::ostringstream os; + os << static_cast(v); + return os.str(); + } +}; + +using int4 = i4; +using uint4 = i4; + +} // namespace ml_dtypes + +#endif // ML_DTYPES_INT4_H_ diff --git a/setup.py b/setup.py index 38b7a1d0..a0e9b205 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def find_package_modules(self, package, package_dir): setup( ext_modules=[ Pybind11Extension( - "ml_dtypes._custom_floats", + "ml_dtypes._ml_dtypes_lib", [ "ml_dtypes/_src/dtypes.cc", "ml_dtypes/_src/numpy.cc",