Skip to content

Commit

Permalink
Add float8 numpy binding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 566743428
  • Loading branch information
ChromeHearts authored and The ml_dtypes Authors committed Sep 20, 2023
1 parent fc69958 commit 3fba583
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,20 @@ bool RegisterInt4UFuncs(PyObject* numpy) {

template <typename T>
bool RegisterInt4Dtype(PyObject* numpy) {
int typenum =
PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
if (typenum != NPY_NOTYPE) {
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
// The test for an argmax function here is to verify that the
// (u)int4 implementation is sufficiently new, and, say, not from
// an older version of TF or JAX.
if (descr && descr->f && descr->f->argmax) {
TypeDescriptor<T>::npy_type = typenum;
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
return true;
}
}

Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Safe_PyObjectPtr qualname =
Expand Down

0 comments on commit 3fba583

Please sign in to comment.