diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 2786fde0..9f292eba 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -741,121 +741,106 @@ bool RegisterFloatCasts() { template bool RegisterFloatUFuncs(PyObject* numpy) { bool ok = - RegisterUFunc>, T>(numpy, "add") && - RegisterUFunc>, T>(numpy, - "subtract") && - RegisterUFunc>, T>(numpy, - "multiply") && - RegisterUFunc>, T>(numpy, - "divide") && - RegisterUFunc>, T>(numpy, - "logaddexp") && - RegisterUFunc>, T>( - numpy, "logaddexp2") && - RegisterUFunc>, T>(numpy, - "negative") && - RegisterUFunc>, T>(numpy, - "positive") && - RegisterUFunc>, T>( - numpy, "true_divide") && - RegisterUFunc>, T>( + RegisterUFunc, T, T, T>, T>(numpy, "add") && + RegisterUFunc, T, T, T>, T>(numpy, + "subtract") && + RegisterUFunc, T, T, T>, T>(numpy, + "multiply") && + RegisterUFunc, T, T, T>, T>(numpy, + "divide") && + RegisterUFunc, T, T, T>, T>(numpy, + "logaddexp") && + RegisterUFunc, T, T, T>, T>(numpy, + "logaddexp2") && + RegisterUFunc, T, T>, T>(numpy, "negative") && + RegisterUFunc, T, T>, T>(numpy, "positive") && + RegisterUFunc, T, T, T>, T>(numpy, + "true_divide") && + RegisterUFunc, T, T, T>, T>( numpy, "floor_divide") && - RegisterUFunc>, T>(numpy, "power") && - RegisterUFunc>, T>(numpy, - "remainder") && - RegisterUFunc>, T>(numpy, "mod") && - RegisterUFunc>, T>(numpy, "fmod") && - RegisterUFunc, T>(numpy, "divmod") && - RegisterUFunc>, T>(numpy, "absolute") && - RegisterUFunc>, T>(numpy, "fabs") && - RegisterUFunc>, T>(numpy, "rint") && - RegisterUFunc>, T>(numpy, "sign") && - RegisterUFunc>, T>(numpy, - "heaviside") && - RegisterUFunc>, T>(numpy, - "conjugate") && - RegisterUFunc>, T>(numpy, "exp") && - RegisterUFunc>, T>(numpy, "exp2") && - RegisterUFunc>, T>(numpy, "expm1") && - RegisterUFunc>, T>(numpy, "log") && - RegisterUFunc>, T>(numpy, "log2") && - RegisterUFunc>, T>(numpy, "log10") && - RegisterUFunc>, T>(numpy, "log1p") && - RegisterUFunc>, T>(numpy, "sqrt") && - RegisterUFunc>, T>(numpy, "square") && - RegisterUFunc>, T>(numpy, "cbrt") && - RegisterUFunc>, T>(numpy, - "reciprocal") && + RegisterUFunc, T, T, T>, T>(numpy, "power") && + RegisterUFunc, T, T, T>, T>(numpy, + "remainder") && + RegisterUFunc, T, T, T>, T>(numpy, "mod") && + RegisterUFunc, T, T, T>, T>(numpy, "fmod") && + RegisterUFunc, T, T, T, T>, T>(numpy, + "divmod") && + RegisterUFunc, T, T>, T>(numpy, "absolute") && + RegisterUFunc, T, T>, T>(numpy, "fabs") && + RegisterUFunc, T, T>, T>(numpy, "rint") && + RegisterUFunc, T, T>, T>(numpy, "sign") && + RegisterUFunc, T, T, T>, T>(numpy, + "heaviside") && + RegisterUFunc, T, T>, T>(numpy, "conjugate") && + RegisterUFunc, T, T>, T>(numpy, "exp") && + RegisterUFunc, T, T>, T>(numpy, "exp2") && + RegisterUFunc, T, T>, T>(numpy, "expm1") && + RegisterUFunc, T, T>, T>(numpy, "log") && + RegisterUFunc, T, T>, T>(numpy, "log2") && + RegisterUFunc, T, T>, T>(numpy, "log10") && + RegisterUFunc, T, T>, T>(numpy, "log1p") && + RegisterUFunc, T, T>, T>(numpy, "sqrt") && + RegisterUFunc, T, T>, T>(numpy, "square") && + RegisterUFunc, T, T>, T>(numpy, "cbrt") && + RegisterUFunc, T, T>, T>(numpy, + "reciprocal") && // Trigonometric functions - RegisterUFunc>, T>(numpy, "sin") && - RegisterUFunc>, T>(numpy, "cos") && - RegisterUFunc>, T>(numpy, "tan") && - RegisterUFunc>, T>(numpy, "arcsin") && - RegisterUFunc>, T>(numpy, "arccos") && - RegisterUFunc>, T>(numpy, "arctan") && - RegisterUFunc>, T>(numpy, - "arctan2") && - RegisterUFunc>, T>(numpy, "hypot") && - RegisterUFunc>, T>(numpy, "sinh") && - RegisterUFunc>, T>(numpy, "cosh") && - RegisterUFunc>, T>(numpy, "tanh") && - RegisterUFunc>, T>(numpy, - "arcsinh") && - RegisterUFunc>, T>(numpy, - "arccosh") && - RegisterUFunc>, T>(numpy, - "arctanh") && - RegisterUFunc>, T>(numpy, - "deg2rad") && - RegisterUFunc>, T>(numpy, - "rad2deg") && + RegisterUFunc, T, T>, T>(numpy, "sin") && + RegisterUFunc, T, T>, T>(numpy, "cos") && + RegisterUFunc, T, T>, T>(numpy, "tan") && + RegisterUFunc, T, T>, T>(numpy, "arcsin") && + RegisterUFunc, T, T>, T>(numpy, "arccos") && + RegisterUFunc, T, T>, T>(numpy, "arctan") && + RegisterUFunc, T, T, T>, T>(numpy, "arctan2") && + RegisterUFunc, T, T, T>, T>(numpy, "hypot") && + RegisterUFunc, T, T>, T>(numpy, "sinh") && + RegisterUFunc, T, T>, T>(numpy, "cosh") && + RegisterUFunc, T, T>, T>(numpy, "tanh") && + RegisterUFunc, T, T>, T>(numpy, "arcsinh") && + RegisterUFunc, T, T>, T>(numpy, "arccosh") && + RegisterUFunc, T, T>, T>(numpy, "arctanh") && + RegisterUFunc, T, T>, T>(numpy, "deg2rad") && + RegisterUFunc, T, T>, T>(numpy, "rad2deg") && // Comparison functions - RegisterUFunc>, T>(numpy, "equal") && - RegisterUFunc>, T>(numpy, - "not_equal") && - RegisterUFunc>, T>(numpy, "less") && - RegisterUFunc>, T>(numpy, "greater") && - RegisterUFunc>, T>(numpy, - "less_equal") && - RegisterUFunc>, T>(numpy, - "greater_equal") && - RegisterUFunc>, T>(numpy, - "maximum") && - RegisterUFunc>, T>(numpy, - "minimum") && - RegisterUFunc>, T>(numpy, "fmax") && - RegisterUFunc>, T>(numpy, "fmin") && - RegisterUFunc>, T>( + RegisterUFunc, bool, T, T>, T>(numpy, "equal") && + RegisterUFunc, bool, T, T>, T>(numpy, "not_equal") && + RegisterUFunc, bool, T, T>, T>(numpy, "less") && + RegisterUFunc, bool, T, T>, T>(numpy, "greater") && + RegisterUFunc, bool, T, T>, T>(numpy, "less_equal") && + RegisterUFunc, bool, T, T>, T>(numpy, + "greater_equal") && + RegisterUFunc, T, T, T>, T>(numpy, "maximum") && + RegisterUFunc, T, T, T>, T>(numpy, "minimum") && + RegisterUFunc, T, T, T>, T>(numpy, "fmax") && + RegisterUFunc, T, T, T>, T>(numpy, "fmin") && + RegisterUFunc, bool, T, T>, T>( numpy, "logical_and") && - RegisterUFunc>, T>( - numpy, "logical_or") && - RegisterUFunc>, T>( + RegisterUFunc, bool, T, T>, T>(numpy, + "logical_or") && + RegisterUFunc, bool, T, T>, T>( numpy, "logical_xor") && - RegisterUFunc>, T>( - numpy, "logical_not") && + RegisterUFunc, bool, T>, T>(numpy, + "logical_not") && // Floating point functions - RegisterUFunc>, T>(numpy, - "isfinite") && - RegisterUFunc>, T>(numpy, "isinf") && - RegisterUFunc>, T>(numpy, "isnan") && - RegisterUFunc>, T>(numpy, - "signbit") && - RegisterUFunc>, T>(numpy, - "copysign") && - RegisterUFunc>, T>(numpy, "modf") && - RegisterUFunc>, T>(numpy, - "ldexp") && - RegisterUFunc>, T>(numpy, - "frexp") && - RegisterUFunc>, T>(numpy, "floor") && - RegisterUFunc>, T>(numpy, "ceil") && - RegisterUFunc>, T>(numpy, "trunc") && - RegisterUFunc>, T>(numpy, - "nextafter") && - RegisterUFunc>, T>(numpy, "spacing"); + RegisterUFunc, bool, T>, T>(numpy, + "isfinite") && + RegisterUFunc, bool, T>, T>(numpy, "isinf") && + RegisterUFunc, bool, T>, T>(numpy, "isnan") && + RegisterUFunc, bool, T>, T>(numpy, "signbit") && + RegisterUFunc, T, T, T>, T>(numpy, + "copysign") && + RegisterUFunc, T, T, T>, T>(numpy, "modf") && + RegisterUFunc, T, T, int>, T>(numpy, "ldexp") && + RegisterUFunc, T, int, T>, T>(numpy, "frexp") && + RegisterUFunc, T, T>, T>(numpy, "floor") && + RegisterUFunc, T, T>, T>(numpy, "ceil") && + RegisterUFunc, T, T>, T>(numpy, "trunc") && + RegisterUFunc, T, T, T>, T>(numpy, + "nextafter") && + RegisterUFunc, T, T>, T>(numpy, "spacing"); return ok; } diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index ccb4ed63..e184e0b0 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -751,15 +751,15 @@ bool RegisterIntNCasts() { template bool RegisterIntNUFuncs(PyObject* numpy) { - bool ok = RegisterUFunc>, T>(numpy, "add") && - RegisterUFunc>, T>( - numpy, "subtract") && - RegisterUFunc>, T>( - numpy, "multiply") && - RegisterUFunc>, T>( + bool ok = RegisterUFunc, T, T, T>, T>(numpy, "add") && + RegisterUFunc, T, T, T>, T>(numpy, + "subtract") && + RegisterUFunc, T, T, T>, T>(numpy, + "multiply") && + RegisterUFunc, T, T, T>, T>( numpy, "floor_divide") && - RegisterUFunc>, T>( - numpy, "remainder"); + RegisterUFunc, T, T, T>, T>(numpy, + "remainder"); return ok; } diff --git a/ml_dtypes/_src/ufuncs.h b/ml_dtypes/_src/ufuncs.h index 9672a5ac..eea92196 100644 --- a/ml_dtypes/_src/ufuncs.h +++ b/ml_dtypes/_src/ufuncs.h @@ -21,8 +21,13 @@ limitations under the License. #include "_src/numpy.h" // clang-format on +#include // NOLINT #include // NOLINT #include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT #include "_src/common.h" // NOLINT @@ -33,100 +38,74 @@ limitations under the License. namespace ml_dtypes { -template -struct UnaryUFunc { +template +struct UFunc { static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype()}; + return {TypeDescriptor::Dtype()..., + TypeDescriptor::Dtype()}; } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - char* o = args[1]; + static constexpr int kInputArity = sizeof...(InTypes); + + template + static void CallImpl(std::index_sequence, char** args, + const npy_intp* dimensions, const npy_intp* steps, + void* data) { + std::array inputs = {args[Is]...}; + char* o = args[kInputArity]; for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - *reinterpret_cast::T*>(o) = Functor()(x); - i0 += steps[0]; - o += steps[1]; + *reinterpret_cast(o) = + Functor()(*reinterpret_cast(inputs[Is])...); + ([&]() { inputs[Is] += steps[Is]; }(), ...); + o += steps[kInputArity]; } } -}; - -template -struct UnaryUFunc2 { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } static void Call(char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { - const char* i0 = args[0]; - char* o0 = args[1]; - char* o1 = args[2]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - std::tie(*reinterpret_cast::T*>(o0), - *reinterpret_cast::T*>(o1)) = - Functor()(x); - i0 += steps[0]; - o0 += steps[1]; - o1 += steps[2]; - } + return CallImpl(std::index_sequence_for(), args, dimensions, + steps, data); } }; -template -struct BinaryUFunc { +template +struct UFunc2 { static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } - static void Call(char** args, const npy_intp* dimensions, - const npy_intp* steps, void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o = args[2]; + return { + TypeDescriptor::Dtype()..., + TypeDescriptor::Dtype(), + TypeDescriptor::Dtype(), + }; + } + static constexpr int kInputArity = sizeof...(InTypes); + + template + static void CallImpl(std::index_sequence, char** args, + const npy_intp* dimensions, const npy_intp* steps, + void* data) { + std::array inputs = {args[Is]...}; + char* o0 = args[kInputArity]; + char* o1 = args[kInputArity + 1]; for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - auto y = *reinterpret_cast::T*>(i1); - *reinterpret_cast::T*>(o) = - Functor()(x, y); - i0 += steps[0]; - i1 += steps[1]; - o += steps[2]; + std::tie(*reinterpret_cast(o0), + *reinterpret_cast(o1)) = + Functor()(*reinterpret_cast(inputs[Is])...); + ([&]() { inputs[Is] += steps[Is]; }(), ...); + o0 += steps[kInputArity]; + o1 += steps[kInputArity + 1]; } } -}; - -template -struct BinaryUFunc2 { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype()}; - } static void Call(char** args, const npy_intp* dimensions, const npy_intp* steps, void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o = args[2]; - for (npy_intp k = 0; k < *dimensions; k++) { - auto x = *reinterpret_cast::T*>(i0); - auto y = - *reinterpret_cast::T*>(i1); - *reinterpret_cast::T*>(o) = - Functor()(x, y); - i0 += steps[0]; - i1 += steps[1]; - o += steps[2]; - } + return CallImpl(std::index_sequence_for(), args, dimensions, + steps, data); } }; -template +template bool RegisterUFunc(PyObject* numpy, const char* name) { - std::vector types = UFunc::Types(); + std::vector types = UFuncT::Types(); PyUFuncGenericFunction fn = - reinterpret_cast(UFunc::Call); + reinterpret_cast(UFuncT::Call); Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name)); if (!ufunc_obj) { return false; @@ -165,7 +144,7 @@ struct TrueDivide { T operator()(T a, T b) { return a / b; } }; -inline std::pair divmod(float a, float b) { +static std::pair divmod_impl(float a, float b) { if (b == 0.0f) { float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -199,6 +178,14 @@ inline std::pair divmod(float a, float b) { return {floordiv, mod}; } +template +struct Divmod { + std::pair operator()(T a, T b) { + float c, d; + std::tie(c, d) = divmod_impl(static_cast(a), static_cast(b)); + return {T(c), T(d)}; + } +}; template struct FloorDivide { template ::is_floating, bool> = true> T operator()(T a, T b) { - return T(divmod(static_cast(a), static_cast(b)).first); + return T(divmod_impl(static_cast(a), static_cast(b)).first); } }; template @@ -240,36 +227,10 @@ struct Remainder { template ::is_floating, bool> = true> T operator()(T a, T b) { - return T(divmod(static_cast(a), static_cast(b)).second); - } -}; -template -struct DivmodUFunc { - static std::vector Types() { - return {TypeDescriptor::Dtype(), TypeDescriptor::Dtype(), - TypeDescriptor::Dtype(), TypeDescriptor::Dtype()}; - } - static void Call(char** args, npy_intp* dimensions, npy_intp* steps, - void* data) { - const char* i0 = args[0]; - const char* i1 = args[1]; - char* o0 = args[2]; - char* o1 = args[3]; - for (npy_intp k = 0; k < *dimensions; k++) { - T x = *reinterpret_cast(i0); - T y = *reinterpret_cast(i1); - float floordiv, mod; - std::tie(floordiv, mod) = - divmod(static_cast(x), static_cast(y)); - *reinterpret_cast(o0) = T(floordiv); - *reinterpret_cast(o1) = T(mod); - i0 += steps[0]; - i1 += steps[1]; - o0 += steps[2]; - o1 += steps[3]; - } + return T(divmod_impl(static_cast(a), static_cast(b)).second); } }; + template struct Fmod { T operator()(T a, T b) {