Skip to content

Commit

Permalink
Switch ml_dtypes to use the Python Stable ABI.
Browse files Browse the repository at this point in the history
This allows us to build a single wheel per platform using Python 3.9 and
deploy it on all later CPython versions, with the exception of the
free-threaded 3.13t build, which is still built without the limited API.

The main downside of this change is that it makes the hash function for
scalars slightly more expensive.
  • Loading branch information
hawkinsp committed Sep 17, 2024
1 parent b65a1f6 commit daf3c1a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 64 deletions.
51 changes: 26 additions & 25 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ template <typename T>
Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
PyTypeObject* type =
reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0));
Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type));
PyCustomFloat<T>* p = reinterpret_cast<PyCustomFloat<T>*>(ref.get());
if (p) {
p->value = x;
Expand Down Expand Up @@ -213,7 +213,9 @@ PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x + y).release();
}
return PyArray_Type.tp_as_number->nb_add(a, b);
auto array_nb_add =
reinterpret_cast<binaryfunc>(PyType_GetSlot(&PyArray_Type, Py_nb_add));
return array_nb_add(a, b);
}

template <typename T>
Expand All @@ -222,7 +224,9 @@ PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x - y).release();
}
return PyArray_Type.tp_as_number->nb_subtract(a, b);
auto array_nb_subtract = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_subtract));
return array_nb_subtract(a, b);
}

template <typename T>
Expand All @@ -231,7 +235,9 @@ PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x * y).release();
}
return PyArray_Type.tp_as_number->nb_multiply(a, b);
auto array_nb_multiply = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_multiply));
return array_nb_multiply(a, b);
}

template <typename T>
Expand All @@ -240,7 +246,9 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x / y).release();
}
return PyArray_Type.tp_as_number->nb_true_divide(a, b);
auto array_nb_true_divide = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_true_divide));
return array_nb_true_divide(a, b);
}

// Constructs a new PyCustomFloat.
Expand Down Expand Up @@ -281,8 +289,7 @@ PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
return PyCustomFloat_FromT<T>(value).release();
}
}
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(arg)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg));
return nullptr;
}

Expand All @@ -291,7 +298,9 @@ template <typename T>
PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
T x, y;
if (!SafeCastToCustomFloat<T>(a, &x) || !SafeCastToCustomFloat<T>(b, &y)) {
return PyGenericArrType_Type.tp_richcompare(a, b, op);
auto generic_tp_richcompare = reinterpret_cast<richcmpfunc>(
PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare));
return generic_tp_richcompare(a, b, op);
}
bool result;
switch (op) {
Expand Down Expand Up @@ -340,25 +349,18 @@ PyObject* PyCustomFloat_Str(PyObject* self) {
return PyUnicode_FromString(s.str().c_str());
}

// _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to
// handle the two possibilities.
// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double),
PyObject* self, double value) {
return hash_double(self, value);
}

// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self,
double value) {
return hash_double(value);
}

// Hash function for PyCustomFloat.
template <typename T>
Py_hash_t PyCustomFloat_Hash(PyObject* self) {
T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
if (std::isnan(x)) {
// NaNs hash as the pointer hash of the object.
auto f = reinterpret_cast<hashfunc>(
PyType_GetSlot(&PyBaseObject_Type, Py_tp_hash));
return f(self);
}
Safe_PyObjectPtr f(PyFloat_FromDouble(static_cast<double>(x)));
return PyObject_Hash(f.get());
}

template <typename T>
Expand Down Expand Up @@ -428,8 +430,7 @@ template <typename T>
int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
T x;
if (!CastToCustomFloat<T>(item, &x)) {
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(item)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item));
return -1;
}
memcpy(data, &x, sizeof(T));
Expand Down
42 changes: 29 additions & 13 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ template <typename T>
Safe_PyObjectPtr PyIntN_FromValue(T x) {
PyTypeObject* type =
reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0));
Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type));
PyIntN<T>* p = reinterpret_cast<PyIntN<T>*>(ref.get());
if (p) {
p->value = x;
Expand Down Expand Up @@ -214,16 +214,21 @@ PyObject* PyIntN_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
}
} else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
// Parse an integer from the string, then cast to T.
PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0);
if (PyErr_Occurred()) {
Safe_PyObjectPtr bytes(PyUnicode_AsUTF8String(arg));
if (!bytes) {
return nullptr;
}
PyObject* f =
PyLong_FromString(PyBytes_AsString(bytes.get()), /*end=*/nullptr,
/*base=*/0);
if (!f) {
return nullptr;
}
if (CastToIntN<T>(f, &value)) {
return PyIntN_FromValue<T>(value).release();
}
}
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(arg)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg));
return nullptr;
}

Expand Down Expand Up @@ -257,7 +262,9 @@ PyObject* PyIntN_nb_add(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x + y).release();
}
return PyArray_Type.tp_as_number->nb_add(a, b);
auto array_nb_add =
reinterpret_cast<binaryfunc>(PyType_GetSlot(&PyArray_Type, Py_nb_add));
return array_nb_add(a, b);
}

template <typename T>
Expand All @@ -266,7 +273,9 @@ PyObject* PyIntN_nb_subtract(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x - y).release();
}
return PyArray_Type.tp_as_number->nb_subtract(a, b);
auto array_nb_subtract = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_subtract));
return array_nb_subtract(a, b);
}

template <typename T>
Expand All @@ -275,7 +284,9 @@ PyObject* PyIntN_nb_multiply(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x * y).release();
}
return PyArray_Type.tp_as_number->nb_multiply(a, b);
auto array_nb_multiply = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_multiply));
return array_nb_multiply(a, b);
}

template <typename T>
Expand All @@ -292,7 +303,9 @@ PyObject* PyIntN_nb_remainder(PyObject* a, PyObject* b) {
}
return PyIntN_FromValue<T>(v).release();
}
return PyArray_Type.tp_as_number->nb_remainder(a, b);
auto array_nb_remainder = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_remainder));
return array_nb_remainder(a, b);
}

template <typename T>
Expand All @@ -309,7 +322,9 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) {
}
return PyIntN_FromValue<T>(v).release();
}
return PyArray_Type.tp_as_number->nb_floor_divide(a, b);
auto array_nb_floor_divide = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_floor_divide));
return array_nb_floor_divide(a, b);
}

// Implementation of repr() for PyIntN.
Expand Down Expand Up @@ -342,7 +357,9 @@ template <typename T>
PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) {
T x, y;
if (!PyIntN_Value<T>(a, &x) || !PyIntN_Value<T>(b, &y)) {
return PyGenericArrType_Type.tp_richcompare(a, b, op);
auto generic_tp_richcompare = reinterpret_cast<richcmpfunc>(
PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare));
return generic_tp_richcompare(a, b, op);
}
bool result;
switch (op) {
Expand Down Expand Up @@ -440,8 +457,7 @@ template <typename T>
int NPyIntN_SetItem(PyObject* item, void* data, void* arr) {
T x;
if (!CastToIntN<T>(item, &x)) {
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(item)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item));
return -1;
}
memcpy(data, &x, sizeof(T));
Expand Down
59 changes: 33 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,50 @@

import fnmatch
import platform
import sysconfig

import numpy as np
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_py import build_py as build_py_orig

free_threading = sysconfig.get_config_var("Py_GIL_DISABLED")

if platform.system() == "Windows":
COMPILE_ARGS = [
"/std:c++17",
"/DEIGEN_MPL2_ONLY",
"/EHsc",
"/bigobj",
]
COMPILE_ARGS = [
"/std:c++17",
"/DEIGEN_MPL2_ONLY",
"/EHsc",
"/bigobj",
]
else:
COMPILE_ARGS = [
"-std=c++17",
"-DEIGEN_MPL2_ONLY",
"-fvisibility=hidden",
# -ftrapping-math is necessary because NumPy looks at floating point
# exception state to determine whether to emit, e.g., invalid value
# warnings. Without this setting, on Mac ARM we see spurious "invalid
# value" warnings when running the tests.
"-ftrapping-math",
]
COMPILE_ARGS = [
"-std=c++17",
"-DEIGEN_MPL2_ONLY",
"-fvisibility=hidden",
# -ftrapping-math is necessary because NumPy looks at floating point
# exception state to determine whether to emit, e.g., invalid value
# warnings. Without this setting, on Mac ARM we see spurious "invalid
# value" warnings when running the tests.
"-ftrapping-math",
]
if not free_threading:
COMPILE_ARGS.append("-DPy_LIMITED_API=0x03090000")

exclude = ["third_party*"]


class build_py(build_py_orig): # pylint: disable=invalid-name

def find_package_modules(self, package, package_dir):
modules = super().find_package_modules(package, package_dir)
return [ # pylint: disable=g-complex-comprehension
(pkg, mod, file)
for (pkg, mod, file) in modules
if not any(
fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern)
for pattern in exclude
)
]
def find_package_modules(self, package, package_dir):
modules = super().find_package_modules(package, package_dir)
return [ # pylint: disable=g-complex-comprehension
(pkg, mod, file)
for (pkg, mod, file) in modules
if not any(
fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) for pattern in exclude
)
]


setup(
Expand All @@ -71,7 +76,9 @@ def find_package_modules(self, package, package_dir):
np.get_include(),
],
extra_compile_args=COMPILE_ARGS,
py_limited_api=not free_threading,
)
],
cmdclass={"build_py": build_py},
options={} if free_threading else {"bdist_wheel": {"py_limited_api": "cp39"}},
)

0 comments on commit daf3c1a

Please sign in to comment.