Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add float8 & int4 numpy integration #103

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

## [0.3.1] - 2023-09-22

* Added support for int4 casting to wider integers such as int8
* Addes support to cast np.float32 and np.float64 into int4

## [0.3.0] - 2023-09-19

* Dropped support for Python 3.8, following [NEP 29].
Expand All @@ -44,7 +49,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...HEAD
[Unreleased]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.1...HEAD
[0.3.1]: https://github.com/jax-ml/ml_dtypes/compare/v0.3.0...v0.3.1
[0.3.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/jax-ml/ml_dtypes/compare/v0.1.0...v0.2.0
[0.1.0]: https://github.com/jax-ml/ml_dtypes/releases/tag/v0.1.0
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.3.0' # Keep in sync with pyproject.toml:version
__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
__all__ = [
'__version__',
'bfloat16',
Expand Down
66 changes: 63 additions & 3 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ limitations under the License.

namespace ml_dtypes {

constexpr char kOutOfRange[] = "out of range value cannot be converted to int4";

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -114,8 +116,7 @@ bool CastToInt4(PyObject* arg, T* output) {
}
if (d < static_cast<double>(T::lowest()) ||
d > static_cast<double>(T::highest())) {
PyErr_SetString(PyExc_OverflowError,
"out of range value cannot be converted to int4");
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
}
*output = T(d);
return true;
Expand All @@ -131,9 +132,37 @@ bool CastToInt4(PyObject* arg, T* output) {
if (PyArray_IsScalar(arg, Integer)) {
int64_t v;
PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64));

if (!(std::numeric_limits<T>::min() <= v &&
v <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(v);
return true;
}
if (PyArray_IsScalar(arg, Float)) {
float f;
PyArray_ScalarAsCtype(arg, &f);
if (!(std::numeric_limits<T>::min() <= f &&
f <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(f));
return true;
}
if (PyArray_IsScalar(arg, Double)) {
double d;
PyArray_ScalarAsCtype(arg, &d);
if (!(std::numeric_limits<T>::min() <= d &&
d <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(d));
return true;
}
return false;
}

Expand Down Expand Up @@ -652,7 +681,38 @@ bool RegisterInt4Casts() {
}

// Safe casts from T to other types
// TODO(phawkins): add integer types
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT64,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT64,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
NPY_NOSCALAR) < 0) {
return false;
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ml_dtypes"
version = "0.3.0" # Keep in sync with ml_dtypes/__init__.py:__version__
version = "0.3.1" # Keep in sync with ml_dtypes/__init__.py:__version__
description = ""
readme = "README.md"
requires-python = ">=3.9"
Expand Down
Loading