Skip to content

Commit

Permalink
Refactor ufunc implementations to use variadic template arguments for…
Browse files Browse the repository at this point in the history
… the inputs.

Allows consolidating down to one ufunc implementation for each output arity.

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 675957531
  • Loading branch information
hawkinsp authored and The ml_dtypes Authors committed Sep 18, 2024
1 parent 9a56b09 commit 3a3ccce
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 217 deletions.
197 changes: 91 additions & 106 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -741,121 +741,106 @@ bool RegisterFloatCasts() {
template <typename T>
bool RegisterFloatUFuncs(PyObject* numpy) {
bool ok =
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy,
"subtract") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy,
"multiply") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(numpy,
"divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy,
"logaddexp") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(
numpy, "logaddexp2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy,
"negative") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy,
"positive") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(
numpy, "true_divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
RegisterUFunc<UFunc<ufuncs::Add<T>, T, T, T>, T>(numpy, "add") &&
RegisterUFunc<UFunc<ufuncs::Subtract<T>, T, T, T>, T>(numpy,
"subtract") &&
RegisterUFunc<UFunc<ufuncs::Multiply<T>, T, T, T>, T>(numpy,
"multiply") &&
RegisterUFunc<UFunc<ufuncs::TrueDivide<T>, T, T, T>, T>(numpy,
"divide") &&
RegisterUFunc<UFunc<ufuncs::LogAddExp<T>, T, T, T>, T>(numpy,
"logaddexp") &&
RegisterUFunc<UFunc<ufuncs::LogAddExp2<T>, T, T, T>, T>(numpy,
"logaddexp2") &&
RegisterUFunc<UFunc<ufuncs::Negative<T>, T, T>, T>(numpy, "negative") &&
RegisterUFunc<UFunc<ufuncs::Positive<T>, T, T>, T>(numpy, "positive") &&
RegisterUFunc<UFunc<ufuncs::TrueDivide<T>, T, T, T>, T>(numpy,
"true_divide") &&
RegisterUFunc<UFunc<ufuncs::FloorDivide<T>, T, T, T>, T>(
numpy, "floor_divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy,
"remainder") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
RegisterUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy,
"heaviside") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy,
"conjugate") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy,
"reciprocal") &&
RegisterUFunc<UFunc<ufuncs::Power<T>, T, T, T>, T>(numpy, "power") &&
RegisterUFunc<UFunc<ufuncs::Remainder<T>, T, T, T>, T>(numpy,
"remainder") &&
RegisterUFunc<UFunc<ufuncs::Remainder<T>, T, T, T>, T>(numpy, "mod") &&
RegisterUFunc<UFunc<ufuncs::Fmod<T>, T, T, T>, T>(numpy, "fmod") &&
RegisterUFunc<UFunc2<ufuncs::Divmod<T>, T, T, T, T>, T>(numpy,
"divmod") &&
RegisterUFunc<UFunc<ufuncs::Abs<T>, T, T>, T>(numpy, "absolute") &&
RegisterUFunc<UFunc<ufuncs::Abs<T>, T, T>, T>(numpy, "fabs") &&
RegisterUFunc<UFunc<ufuncs::Rint<T>, T, T>, T>(numpy, "rint") &&
RegisterUFunc<UFunc<ufuncs::Sign<T>, T, T>, T>(numpy, "sign") &&
RegisterUFunc<UFunc<ufuncs::Heaviside<T>, T, T, T>, T>(numpy,
"heaviside") &&
RegisterUFunc<UFunc<ufuncs::Conjugate<T>, T, T>, T>(numpy, "conjugate") &&
RegisterUFunc<UFunc<ufuncs::Exp<T>, T, T>, T>(numpy, "exp") &&
RegisterUFunc<UFunc<ufuncs::Exp2<T>, T, T>, T>(numpy, "exp2") &&
RegisterUFunc<UFunc<ufuncs::Expm1<T>, T, T>, T>(numpy, "expm1") &&
RegisterUFunc<UFunc<ufuncs::Log<T>, T, T>, T>(numpy, "log") &&
RegisterUFunc<UFunc<ufuncs::Log2<T>, T, T>, T>(numpy, "log2") &&
RegisterUFunc<UFunc<ufuncs::Log10<T>, T, T>, T>(numpy, "log10") &&
RegisterUFunc<UFunc<ufuncs::Log1p<T>, T, T>, T>(numpy, "log1p") &&
RegisterUFunc<UFunc<ufuncs::Sqrt<T>, T, T>, T>(numpy, "sqrt") &&
RegisterUFunc<UFunc<ufuncs::Square<T>, T, T>, T>(numpy, "square") &&
RegisterUFunc<UFunc<ufuncs::Cbrt<T>, T, T>, T>(numpy, "cbrt") &&
RegisterUFunc<UFunc<ufuncs::Reciprocal<T>, T, T>, T>(numpy,
"reciprocal") &&

// Trigonometric functions
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy,
"arctan2") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy,
"arcsinh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy,
"arccosh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy,
"arctanh") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy,
"deg2rad") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy,
"rad2deg") &&
RegisterUFunc<UFunc<ufuncs::Sin<T>, T, T>, T>(numpy, "sin") &&
RegisterUFunc<UFunc<ufuncs::Cos<T>, T, T>, T>(numpy, "cos") &&
RegisterUFunc<UFunc<ufuncs::Tan<T>, T, T>, T>(numpy, "tan") &&
RegisterUFunc<UFunc<ufuncs::Arcsin<T>, T, T>, T>(numpy, "arcsin") &&
RegisterUFunc<UFunc<ufuncs::Arccos<T>, T, T>, T>(numpy, "arccos") &&
RegisterUFunc<UFunc<ufuncs::Arctan<T>, T, T>, T>(numpy, "arctan") &&
RegisterUFunc<UFunc<ufuncs::Arctan2<T>, T, T, T>, T>(numpy, "arctan2") &&
RegisterUFunc<UFunc<ufuncs::Hypot<T>, T, T, T>, T>(numpy, "hypot") &&
RegisterUFunc<UFunc<ufuncs::Sinh<T>, T, T>, T>(numpy, "sinh") &&
RegisterUFunc<UFunc<ufuncs::Cosh<T>, T, T>, T>(numpy, "cosh") &&
RegisterUFunc<UFunc<ufuncs::Tanh<T>, T, T>, T>(numpy, "tanh") &&
RegisterUFunc<UFunc<ufuncs::Arcsinh<T>, T, T>, T>(numpy, "arcsinh") &&
RegisterUFunc<UFunc<ufuncs::Arccosh<T>, T, T>, T>(numpy, "arccosh") &&
RegisterUFunc<UFunc<ufuncs::Arctanh<T>, T, T>, T>(numpy, "arctanh") &&
RegisterUFunc<UFunc<ufuncs::Deg2rad<T>, T, T>, T>(numpy, "deg2rad") &&
RegisterUFunc<UFunc<ufuncs::Rad2deg<T>, T, T>, T>(numpy, "rad2deg") &&

// Comparison functions
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy,
"not_equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy,
"less_equal") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy,
"greater_equal") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy,
"maximum") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy,
"minimum") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(
RegisterUFunc<UFunc<ufuncs::Eq<T>, bool, T, T>, T>(numpy, "equal") &&
RegisterUFunc<UFunc<ufuncs::Ne<T>, bool, T, T>, T>(numpy, "not_equal") &&
RegisterUFunc<UFunc<ufuncs::Lt<T>, bool, T, T>, T>(numpy, "less") &&
RegisterUFunc<UFunc<ufuncs::Gt<T>, bool, T, T>, T>(numpy, "greater") &&
RegisterUFunc<UFunc<ufuncs::Le<T>, bool, T, T>, T>(numpy, "less_equal") &&
RegisterUFunc<UFunc<ufuncs::Ge<T>, bool, T, T>, T>(numpy,
"greater_equal") &&
RegisterUFunc<UFunc<ufuncs::Maximum<T>, T, T, T>, T>(numpy, "maximum") &&
RegisterUFunc<UFunc<ufuncs::Minimum<T>, T, T, T>, T>(numpy, "minimum") &&
RegisterUFunc<UFunc<ufuncs::Fmax<T>, T, T, T>, T>(numpy, "fmax") &&
RegisterUFunc<UFunc<ufuncs::Fmin<T>, T, T, T>, T>(numpy, "fmin") &&
RegisterUFunc<UFunc<ufuncs::LogicalAnd<T>, bool, T, T>, T>(
numpy, "logical_and") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(
numpy, "logical_or") &&
RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(
RegisterUFunc<UFunc<ufuncs::LogicalOr<T>, bool, T, T>, T>(numpy,
"logical_or") &&
RegisterUFunc<UFunc<ufuncs::LogicalXor<T>, bool, T, T>, T>(
numpy, "logical_xor") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(
numpy, "logical_not") &&
RegisterUFunc<UFunc<ufuncs::LogicalNot<T>, bool, T>, T>(numpy,
"logical_not") &&

// Floating point functions
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy,
"isfinite") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
RegisterUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy,
"signbit") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy,
"copysign") &&
RegisterUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
RegisterUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy,
"ldexp") &&
RegisterUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy,
"frexp") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy,
"nextafter") &&
RegisterUFunc<UnaryUFunc<T, T, ufuncs::Spacing<T>>, T>(numpy, "spacing");
RegisterUFunc<UFunc<ufuncs::IsFinite<T>, bool, T>, T>(numpy,
"isfinite") &&
RegisterUFunc<UFunc<ufuncs::IsInf<T>, bool, T>, T>(numpy, "isinf") &&
RegisterUFunc<UFunc<ufuncs::IsNan<T>, bool, T>, T>(numpy, "isnan") &&
RegisterUFunc<UFunc<ufuncs::SignBit<T>, bool, T>, T>(numpy, "signbit") &&
RegisterUFunc<UFunc<ufuncs::CopySign<T>, T, T, T>, T>(numpy,
"copysign") &&
RegisterUFunc<UFunc2<ufuncs::Modf<T>, T, T, T>, T>(numpy, "modf") &&
RegisterUFunc<UFunc<ufuncs::Ldexp<T>, T, T, int>, T>(numpy, "ldexp") &&
RegisterUFunc<UFunc2<ufuncs::Frexp<T>, T, int, T>, T>(numpy, "frexp") &&
RegisterUFunc<UFunc<ufuncs::Floor<T>, T, T>, T>(numpy, "floor") &&
RegisterUFunc<UFunc<ufuncs::Ceil<T>, T, T>, T>(numpy, "ceil") &&
RegisterUFunc<UFunc<ufuncs::Trunc<T>, T, T>, T>(numpy, "trunc") &&
RegisterUFunc<UFunc<ufuncs::NextAfter<T>, T, T, T>, T>(numpy,
"nextafter") &&
RegisterUFunc<UFunc<ufuncs::Spacing<T>, T, T>, T>(numpy, "spacing");

return ok;
}
Expand Down
16 changes: 8 additions & 8 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,15 +751,15 @@ bool RegisterIntNCasts() {

template <typename T>
bool RegisterIntNUFuncs(PyObject* numpy) {
bool ok = RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(
numpy, "subtract") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(
numpy, "multiply") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
bool ok = RegisterUFunc<UFunc<ufuncs::Add<T>, T, T, T>, T>(numpy, "add") &&
RegisterUFunc<UFunc<ufuncs::Subtract<T>, T, T, T>, T>(numpy,
"subtract") &&
RegisterUFunc<UFunc<ufuncs::Multiply<T>, T, T, T>, T>(numpy,
"multiply") &&
RegisterUFunc<UFunc<ufuncs::FloorDivide<T>, T, T, T>, T>(
numpy, "floor_divide") &&
RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(
numpy, "remainder");
RegisterUFunc<UFunc<ufuncs::Remainder<T>, T, T, T>, T>(numpy,
"remainder");

return ok;
}
Expand Down
Loading

0 comments on commit 3a3ccce

Please sign in to comment.