diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 7f23fbc1..fb072625 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -712,6 +712,20 @@ bool RegisterInt4UFuncs(PyObject* numpy) { template bool RegisterInt4Dtype(PyObject* numpy) { + int typenum = + PyArray_TypeNumFromName(const_cast(TypeDescriptor::kTypeName)); + if (typenum != NPY_NOTYPE) { + PyArray_Descr* descr = PyArray_DescrFromType(typenum); + // The test for an argmax function here is to verify that the + // (u)int4 implementation is sufficiently new, and, say, not from + // an older version of TF or JAX. + if (descr && descr->f && descr->f->argmax) { + TypeDescriptor::npy_type = typenum; + TypeDescriptor::type_ptr = reinterpret_cast(descr->typeobj); + return true; + } + } + Safe_PyObjectPtr name = make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); Safe_PyObjectPtr qualname =