Skip to content

Commit

Permalink
Add float8 & int4 numpy integration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567178215
  • Loading branch information
ChromeHearts authored and The ml_dtypes Authors committed Sep 21, 2023
1 parent fc69958 commit eac12ac
Showing 1 changed file with 63 additions and 3 deletions.
66 changes: 63 additions & 3 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ limitations under the License.

namespace ml_dtypes {

constexpr char kOutOfRange[] = "out of range value cannot be converted to int4";

template <typename T>
struct Int4TypeDescriptor {
static int Dtype() { return npy_type; }
Expand Down Expand Up @@ -114,8 +116,7 @@ bool CastToInt4(PyObject* arg, T* output) {
}
if (d < static_cast<double>(T::lowest()) ||
d > static_cast<double>(T::highest())) {
PyErr_SetString(PyExc_OverflowError,
"out of range value cannot be converted to int4");
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
}
*output = T(d);
return true;
Expand All @@ -131,9 +132,37 @@ bool CastToInt4(PyObject* arg, T* output) {
if (PyArray_IsScalar(arg, Integer)) {
int64_t v;
PyArray_CastScalarToCtype(arg, &v, PyArray_DescrFromType(NPY_INT64));

if (!(std::numeric_limits<T>::min() <= v &&
v <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(v);
return true;
}
if (PyArray_IsScalar(arg, Float)) {
float f;
PyArray_ScalarAsCtype(arg, &f);
if (!(std::numeric_limits<T>::min() <= f &&
f <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(f));
return true;
}
if (PyArray_IsScalar(arg, Double)) {
double d;
PyArray_ScalarAsCtype(arg, &d);
if (!(std::numeric_limits<T>::min() <= d &&
d <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(d));
return true;
}
return false;
}

Expand Down Expand Up @@ -652,7 +681,38 @@ bool RegisterInt4Casts() {
}

// Safe casts from T to other types
// TODO(phawkins): add integer types
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT8,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT16,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT32,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_INT64,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_UINT64,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
NPY_NOSCALAR) < 0) {
return false;
Expand Down

0 comments on commit eac12ac

Please sign in to comment.