Skip to content

Commit

Permalink
Use PyType_FromSpecWithBases to construct the scalar type objects.
Browse files Browse the repository at this point in the history
This is a simpler and more stable API for manufacturing a type.
  • Loading branch information
hawkinsp committed Sep 13, 2024
1 parent b39b73c commit 013b67d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 153 deletions.
117 changes: 39 additions & 78 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ struct CustomFloatType {
// registered by another system into NumPy.
static PyObject* type_ptr;

static PyNumberMethods number_methods;
static PyType_Spec type_spec;
static PyType_Slot type_slots[];
static PyArray_ArrFuncs arr_funcs;
static PyArray_DescrProto npy_descr_proto;
static PyArray_Descr* npy_descr;
Expand Down Expand Up @@ -242,47 +243,6 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
return PyArray_Type.tp_as_number->nb_true_divide(a, b);
}

// Python number methods for PyCustomFloat objects.
template <typename T>
PyNumberMethods CustomFloatType<T>::number_methods = {
PyCustomFloat_Add<T>, // nb_add
PyCustomFloat_Subtract<T>, // nb_subtract
PyCustomFloat_Multiply<T>, // nb_multiply
nullptr, // nb_remainder
nullptr, // nb_divmod
nullptr, // nb_power
PyCustomFloat_Negative<T>, // nb_negative
nullptr, // nb_positive
nullptr, // nb_absolute
nullptr, // nb_nonzero
nullptr, // nb_invert
nullptr, // nb_lshift
nullptr, // nb_rshift
nullptr, // nb_and
nullptr, // nb_xor
nullptr, // nb_or
PyCustomFloat_Int<T>, // nb_int
nullptr, // reserved
PyCustomFloat_Float<T>, // nb_float

nullptr, // nb_inplace_add
nullptr, // nb_inplace_subtract
nullptr, // nb_inplace_multiply
nullptr, // nb_inplace_remainder
nullptr, // nb_inplace_power
nullptr, // nb_inplace_lshift
nullptr, // nb_inplace_rshift
nullptr, // nb_inplace_and
nullptr, // nb_inplace_xor
nullptr, // nb_inplace_or

nullptr, // nb_floor_divide
PyCustomFloat_TrueDivide<T>, // nb_true_divide
nullptr, // nb_inplace_floor_divide
nullptr, // nb_inplace_true_divide
nullptr, // nb_index
};

// Constructs a new PyCustomFloat.
template <typename T>
PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
Expand Down Expand Up @@ -401,6 +361,34 @@ Py_hash_t PyCustomFloat_Hash(PyObject* self) {
return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
}

template <typename T>
PyType_Slot CustomFloatType<T>::type_slots[] = {
{Py_tp_new, reinterpret_cast<void*>(PyCustomFloat_New<T>)},
{Py_tp_repr, reinterpret_cast<void*>(PyCustomFloat_Repr<T>)},
{Py_tp_hash, reinterpret_cast<void*>(PyCustomFloat_Hash<T>)},
{Py_tp_str, reinterpret_cast<void*>(PyCustomFloat_Str<T>)},
{Py_tp_doc,
reinterpret_cast<void*>(const_cast<char*>(TypeDescriptor<T>::kTpDoc))},
{Py_tp_richcompare, reinterpret_cast<void*>(PyCustomFloat_RichCompare<T>)},
{Py_nb_add, reinterpret_cast<void*>(PyCustomFloat_Add<T>)},
{Py_nb_subtract, reinterpret_cast<void*>(PyCustomFloat_Subtract<T>)},
{Py_nb_multiply, reinterpret_cast<void*>(PyCustomFloat_Multiply<T>)},
{Py_nb_negative, reinterpret_cast<void*>(PyCustomFloat_Negative<T>)},
{Py_nb_int, reinterpret_cast<void*>(PyCustomFloat_Int<T>)},
{Py_nb_float, reinterpret_cast<void*>(PyCustomFloat_Float<T>)},
{Py_nb_true_divide, reinterpret_cast<void*>(PyCustomFloat_TrueDivide<T>)},
{0, nullptr},
};

template <typename T>
PyType_Spec CustomFloatType<T>::type_spec = {
/*.name=*/TypeDescriptor<T>::kQualifiedTypeName,
/*.basicsize=*/static_cast<int>(sizeof(PyCustomFloat<T>)),
/*.itemsize=*/0,
/*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
/*.slots=*/CustomFloatType<T>::type_slots,
};

// Numpy support
template <typename T>
PyArray_ArrFuncs CustomFloatType<T>::arr_funcs;
Expand Down Expand Up @@ -874,46 +862,19 @@ bool RegisterFloatUFuncs(PyObject* numpy) {

template <typename T>
bool RegisterFloatDtype(PyObject* numpy) {
// TODO(jakevdp): simplify this; we no longer need heap allocation.
Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Safe_PyObjectPtr qualname =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));

PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
PyType_Type.tp_alloc(&PyType_Type, 0));
if (!heap_type) {
return false;
}
// Caution: we must not call any functions that might invoke the GC until
// PyType_Ready() is called. Otherwise the GC might see a half-constructed
// type object.
heap_type->ht_name = name.release();
heap_type->ht_qualname = qualname.release();
PyTypeObject* type = &heap_type->ht_type;
type->tp_name = TypeDescriptor<T>::kTypeName;
type->tp_basicsize = sizeof(PyCustomFloat<T>);
type->tp_flags =
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_base = &PyGenericArrType_Type;
type->tp_new = PyCustomFloat_New<T>;
type->tp_repr = PyCustomFloat_Repr<T>;
type->tp_hash = PyCustomFloat_Hash<T>;
type->tp_str = PyCustomFloat_Str<T>;
type->tp_doc = const_cast<char*>(TypeDescriptor<T>::kTpDoc);
type->tp_richcompare = PyCustomFloat_RichCompare<T>;
type->tp_as_number = &CustomFloatType<T>::number_methods;
if (PyType_Ready(type) < 0) {
return false;
}
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(type);
PyObject* type = PyType_FromSpecWithBases(
&CustomFloatType<T>::type_spec,
reinterpret_cast<PyObject*>(&PyGenericArrType_Type));
if (!type) {
return false;
}
TypeDescriptor<T>::type_ptr = type;

Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes"));
if (!module) {
return false;
}
if (PyObject_SetAttrString(TypeDescriptor<T>::type_ptr, "__module__",
module.get()) < 0) {
if (PyObject_SetAttrString(type, "__module__", module.get()) < 0) {
return false;
}

Expand All @@ -940,7 +901,7 @@ bool RegisterFloatDtype(PyObject* numpy) {
PyArray_DescrProto& descr_proto = CustomFloatType<T>::npy_descr_proto;
descr_proto = GetCustomFloatDescrProto<T>();
Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type);
descr_proto.typeobj = type;
descr_proto.typeobj = reinterpret_cast<PyTypeObject*>(type);

TypeDescriptor<T>::npy_type = PyArray_RegisterDataType(&descr_proto);
if (TypeDescriptor<T>::npy_type < 0) {
Expand Down
116 changes: 41 additions & 75 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ struct IntNTypeDescriptor {
// registered by another system into NumPy.
static PyObject* type_ptr;

static PyNumberMethods number_methods;
static PyType_Spec type_spec;
static PyType_Slot type_slots[];

static PyArray_ArrFuncs arr_funcs;
static PyArray_DescrProto npy_descr_proto;
static PyArray_Descr* npy_descr;
Expand Down Expand Up @@ -310,47 +312,6 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) {
return PyArray_Type.tp_as_number->nb_floor_divide(a, b);
}

// Python number methods for PyIntN objects.
template <typename T>
PyNumberMethods IntNTypeDescriptor<T>::number_methods = {
PyIntN_nb_add<T>, // nb_add
PyIntN_nb_subtract<T>, // nb_subtract
PyIntN_nb_multiply<T>, // nb_multiply
PyIntN_nb_remainder<T>, // nb_remainder
nullptr, // nb_divmod
nullptr, // nb_power
PyIntN_nb_negative<T>, // nb_negative
PyIntN_nb_positive<T>, // nb_positive
nullptr, // nb_absolute
nullptr, // nb_nonzero
nullptr, // nb_invert
nullptr, // nb_lshift
nullptr, // nb_rshift
nullptr, // nb_and
nullptr, // nb_xor
nullptr, // nb_or
PyIntN_nb_int<T>, // nb_int
nullptr, // reserved
PyIntN_nb_float<T>, // nb_float

nullptr, // nb_inplace_add
nullptr, // nb_inplace_subtract
nullptr, // nb_inplace_multiply
nullptr, // nb_inplace_remainder
nullptr, // nb_inplace_power
nullptr, // nb_inplace_lshift
nullptr, // nb_inplace_rshift
nullptr, // nb_inplace_and
nullptr, // nb_inplace_xor
nullptr, // nb_inplace_or

PyIntN_nb_floor_divide<T>, // nb_floor_divide
nullptr, // nb_true_divide
nullptr, // nb_inplace_floor_divide
nullptr, // nb_inplace_true_divide
nullptr, // nb_index
};

// Implementation of repr() for PyIntN.
template <typename T>
PyObject* PyIntN_Repr(PyObject* self) {
Expand Down Expand Up @@ -410,6 +371,36 @@ PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) {
PyArrayScalar_RETURN_BOOL_FROM_LONG(result);
}

template <typename T>
PyType_Slot IntNTypeDescriptor<T>::type_slots[] = {
{Py_tp_new, reinterpret_cast<void*>(PyIntN_tp_new<T>)},
{Py_tp_repr, reinterpret_cast<void*>(PyIntN_Repr<T>)},
{Py_tp_hash, reinterpret_cast<void*>(PyIntN_Hash<T>)},
{Py_tp_str, reinterpret_cast<void*>(PyIntN_Str<T>)},
{Py_tp_doc,
reinterpret_cast<void*>(const_cast<char*>(TypeDescriptor<T>::kTpDoc))},
{Py_tp_richcompare, reinterpret_cast<void*>(PyIntN_RichCompare<T>)},
{Py_nb_add, reinterpret_cast<void*>(PyIntN_nb_add<T>)},
{Py_nb_subtract, reinterpret_cast<void*>(PyIntN_nb_subtract<T>)},
{Py_nb_multiply, reinterpret_cast<void*>(PyIntN_nb_multiply<T>)},
{Py_nb_remainder, reinterpret_cast<void*>(PyIntN_nb_remainder<T>)},
{Py_nb_negative, reinterpret_cast<void*>(PyIntN_nb_negative<T>)},
{Py_nb_positive, reinterpret_cast<void*>(PyIntN_nb_positive<T>)},
{Py_nb_int, reinterpret_cast<void*>(PyIntN_nb_int<T>)},
{Py_nb_float, reinterpret_cast<void*>(PyIntN_nb_float<T>)},
{Py_nb_floor_divide, reinterpret_cast<void*>(PyIntN_nb_floor_divide<T>)},
{0, nullptr},
};

template <typename T>
PyType_Spec IntNTypeDescriptor<T>::type_spec = {
/*.name=*/TypeDescriptor<T>::kQualifiedTypeName,
/*.basicsize=*/static_cast<int>(sizeof(PyIntN<T>)),
/*.itemsize=*/0,
/*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
/*.slots=*/IntNTypeDescriptor<T>::type_slots,
};

// Numpy support
template <typename T>
PyArray_ArrFuncs IntNTypeDescriptor<T>::arr_funcs;
Expand Down Expand Up @@ -775,38 +766,13 @@ bool RegisterIntNUFuncs(PyObject* numpy) {

template <typename T>
bool RegisterIntNDtype(PyObject* numpy) {
Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Safe_PyObjectPtr qualname =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));

PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
PyType_Type.tp_alloc(&PyType_Type, 0));
if (!heap_type) {
return false;
}
// Caution: we must not call any functions that might invoke the GC until
// PyType_Ready() is called. Otherwise the GC might see a half-constructed
// type object.
heap_type->ht_name = name.release();
heap_type->ht_qualname = qualname.release();
PyTypeObject* type = &heap_type->ht_type;
type->tp_name = TypeDescriptor<T>::kTypeName;
type->tp_basicsize = sizeof(PyIntN<T>);
type->tp_flags =
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_base = &PyGenericArrType_Type;
type->tp_new = PyIntN_tp_new<T>;
type->tp_repr = PyIntN_Repr<T>;
type->tp_hash = PyIntN_Hash<T>;
type->tp_str = PyIntN_Str<T>;
type->tp_doc = const_cast<char*>(TypeDescriptor<T>::kTpDoc);
type->tp_richcompare = PyIntN_RichCompare<T>;
type->tp_as_number = &IntNTypeDescriptor<T>::number_methods;
if (PyType_Ready(type) < 0) {
return false;
}
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(type);
PyObject* type = PyType_FromSpecWithBases(
&IntNTypeDescriptor<T>::type_spec,
reinterpret_cast<PyObject*>(&PyGenericArrType_Type));
if (!type) {
return false;
}
TypeDescriptor<T>::type_ptr = type;

Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes"));
if (!module) {
Expand Down Expand Up @@ -840,7 +806,7 @@ bool RegisterIntNDtype(PyObject* numpy) {
PyArray_DescrProto& descr_proto = IntNTypeDescriptor<T>::npy_descr_proto;
descr_proto = GetIntNDescrProto<T>();
Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type);
descr_proto.typeobj = type;
descr_proto.typeobj = reinterpret_cast<PyTypeObject*>(type);

TypeDescriptor<T>::npy_type = PyArray_RegisterDataType(&descr_proto);
if (TypeDescriptor<T>::npy_type < 0) {
Expand Down

0 comments on commit 013b67d

Please sign in to comment.