From 3fba583d223ebd12da8510295920be89cae84cd1 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Tue, 19 Sep 2023 14:25:36 -0700 Subject: [PATCH] Add float8 numpy binding PiperOrigin-RevId: 566743428 --- ml_dtypes/_src/int4_numpy.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 7f23fbc1..fb072625 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -712,6 +712,20 @@ bool RegisterInt4UFuncs(PyObject* numpy) { template bool RegisterInt4Dtype(PyObject* numpy) { + int typenum = + PyArray_TypeNumFromName(const_cast(TypeDescriptor::kTypeName)); + if (typenum != NPY_NOTYPE) { + PyArray_Descr* descr = PyArray_DescrFromType(typenum); + // The test for an argmax function here is to verify that the + // (u)int4 implementation is sufficiently new, and, say, not from + // an older version of TF or JAX. + if (descr && descr->f && descr->f->argmax) { + TypeDescriptor::npy_type = typenum; + TypeDescriptor::type_ptr = reinterpret_cast(descr->typeobj); + return true; + } + } + Safe_PyObjectPtr name = make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); Safe_PyObjectPtr qualname =