From 196fff20dc66bfba1edce3bee3dc637da96c7758 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 28 Aug 2024 16:57:57 -0700 Subject: [PATCH] First pass at wrapping ufuncs by default --- jax/_src/numpy/reductions.py | 10 +- jax/_src/numpy/ufunc_api.py | 91 ++++--------- jax/_src/numpy/ufuncs.py | 58 +++++--- jax/numpy/__init__.pyi | 48 +++++-- tests/lax_numpy_ufuncs_test.py | 236 ++++++++++++++++++++++++--------- 5 files changed, 281 insertions(+), 162 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index e8815c943ce6..dddb44dc9207 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,6 @@ from jax._src import core from jax._src import deprecations from jax._src import dtypes -from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) @@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape if squash_nans: - a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = lax.sort(a, dimension=axis) - counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) @@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, Array([1.5, 3. , 4.5], dtype=float32) """ check_arraylike("nanpercentile", a, q) - q = ufuncs.true_divide(q, 100.0) + q, = promote_dtypes_inexact(q) + q = q / 100 if not isinstance(interpolation, DeprecatedArg): deprecations.warn( "jax-numpy-quantile-interpolation", diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2e114193af13..1d0823bea4fb 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Callable -from functools import partial +from functools import cached_property, partial import math import operator from typing import Any @@ -25,13 +25,11 @@ import jax from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.lax import lax as lax_internal -from jax._src.numpy import reductions -from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take +import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis -from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where +from jax._src.numpy.util import implements, check_arraylike, promote_args, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module -from jax._src import pjit import numpy as np @@ -42,40 +40,6 @@ """ -def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None: - """ - If fun(*args) lowers to a single primitive with inputs and outputs matching - function inputs and outputs, return that primitive. Otherwise return None. - """ - try: - jaxpr = jax.make_jaxpr(fun)(*args) - except: - return None - while len(jaxpr.eqns) == 1: - eqn = jaxpr.eqns[0] - if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars): - return None - elif (eqn.primitive == pjit.pjit_p and - all(pjit.is_unspecified(sharding) for sharding in - (*eqn.params['in_shardings'], *eqn.params['out_shardings']))): - jaxpr = jaxpr.eqns[0].params['jaxpr'] - else: - return jaxpr.eqns[0].primitive - return None - - -_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.sum, - lax_internal.mul_p: reductions.prod, -} - - -_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.cumsum, - lax_internal.mul_p: reductions.cumprod, -} - - @set_module('jax.numpy') class ufunc: """Functions that operate element-by-element on whole arrays. @@ -86,25 +50,27 @@ def __init__(self, func: Callable[..., Any], /, nin: int, nout: int, *, name: str | None = None, nargs: int | None = None, - identity: Any = None, update_doc=False): + identity: Any = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None): + self.__doc__ = func.__doc__ + self.__name__ = name or func.__name__ # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be # mutated. We prevent this by storing them in a dunder attribute, # and accessing them via read-only properties. - if update_doc: - self.__doc__ = func.__doc__ - self.__name__ = name or func.__name__ self.__static_props = { 'func': func, - 'call': vectorize(func), 'nin': operator.index(nin), 'nout': operator.index(nout), 'nargs': operator.index(nargs or nin), - 'identity': identity + 'identity': identity, + 'reduce': reduce, + 'accumulate': accumulate, } _func = property(lambda self: self.__static_props['func']) - _call = property(lambda self: self.__static_props['call']) + _call = cached_property(lambda self: jax.jit(vectorize(self._func))) nin = property(lambda self: self.__static_props['nin']) nout = property(lambda self: self.__static_props['nout']) nargs = property(lambda self: self.__static_props['nargs']) @@ -124,14 +90,13 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: return f"" - def __call__(self, *args: ArrayLike, - out: None = None, where: None = None, - **kwargs: Any) -> Any: + @partial(jax.jit, static_argnames=['self']) + def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any: if out is not None: raise NotImplementedError(f"out argument of {self}") if where is not None: raise NotImplementedError(f"where argument of {self}") - return self._call(*args, **kwargs) + return self._call(*promote_args(self.__name__, *args)) @implements(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) @@ -154,14 +119,10 @@ def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, "so to use a where mask one has to specify 'initial'.") if lax_internal._dtype(where) != bool: raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - reducer = self._reduce_via_scan - else: - reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) + reducer = self.__static_props['reduce'] or self._reduce_via_scan return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 @@ -231,11 +192,7 @@ def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None raise ValueError("accumulate only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - accumulator = self._accumulate_via_scan - else: - accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) + accumulator = self.__static_props['accumulate'] or self._accumulate_via_scan return accumulator(a, axis=axis, dtype=dtype) def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, @@ -276,7 +233,7 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) - indices = _eliminate_deprecated_list_indexing(indices) + indices = jnp._eliminate_deprecated_list_indexing(indices) if not indices: return a @@ -314,7 +271,7 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) - idx_tuple = _eliminate_deprecated_list_indexing(indices) + idx_tuple = jnp._eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] if a.ndim == 0: @@ -326,14 +283,14 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, if axis is None or isinstance(axis, (tuple, list)): raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) - out = take(a, indices, axis=axis) - ind = jax.lax.expand_dims(append(indices, a.shape[axis]), + out = jnp.take(a, indices, axis=axis) + ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), list(np.delete(np.arange(out.ndim), axis))) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self._call(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) @@ -363,4 +320,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. """ - return ufunc(func, nin, nout, identity=identity, update_doc=True) + return ufunc(func, nin, nout, identity=identity) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df0fe..dca31b94ad3d 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -30,11 +30,13 @@ from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, implements, check_no_float0s) +from jax._src.numpy.ufunc_api import ufunc +from jax._src.numpy import reductions _lax_const = lax._const @@ -299,30 +301,25 @@ def cbrt(x: ArrayLike, /) -> Array: return lax.cbrt(*promote_args_inexact('cbrt', x)) @implements(np.add, module='numpy') -@partial(jit, inline=True) -def add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) @implements(np.multiply, module='numpy') -@partial(jit, inline=True) -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) @implements(np.bitwise_and, module='numpy') -@partial(jit, inline=True) -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_and(*promote_args("bitwise_and", x, y)) @implements(np.bitwise_or, module='numpy') -@partial(jit, inline=True) -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_or(*promote_args("bitwise_or", x, y)) @implements(np.bitwise_xor, module='numpy') -@partial(jit, inline=True) -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @implements(np.left_shift, module='numpy') @@ -377,18 +374,15 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: # Logical ops @implements(np.logical_and, module='numpy') -@partial(jit, inline=True) -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) @implements(np.logical_or, module='numpy') -@partial(jit, inline=True) -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) @implements(np.logical_xor, module='numpy') -@partial(jit, inline=True) -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) @implements(np.logical_not, module='numpy') @@ -1281,3 +1275,33 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t + + +def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_and.reduce()") + result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_or.reduce()") + result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +# Generate ufunc interfaces for several common binary functions. +# We start with binary ufuncs that have well-defined identities. +add = ufunc(_add, name="add", nin=2, nout=1, identity=0, reduce=reductions.sum, accumulate=reductions.cumsum) +multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, reduce=reductions.prod, accumulate=reductions.cumprod) +bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1) +bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0) +bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0) +logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, reduce=_logical_and_reduce) +logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, reduce=_logical_or_reduce) +logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 583f6886e915..96cd458ee32a 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -3,7 +3,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence -from typing import Any, Literal, NamedTuple, TypeVar, Union, overload +from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -28,6 +28,34 @@ _Device = Device ComplexWarning: type +class BinaryUfunc(Protocol): + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... + def reduce(self, arr: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /, **kwargs) -> Array: ... + __array_api_version__: str def __array_namespace_info__() -> ArrayNamespaceInfo: ... @@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ... def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... -def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... +add: BinaryUfunc def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... @@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ... bfloat16: Any def bincount(x: ArrayLike, weights: ArrayLike | None = ..., minlength: int = ..., *, length: int | None = ...) -> Array: ... -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_and: BinaryUfunc def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_or: BinaryUfunc def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_xor: BinaryUfunc def blackman(M: int) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any @@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., - include_initial: bool = ...) -> Array: ... + include_initial: builtins.bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg @@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_or: BinaryUfunc +logical_xor: BinaryUfunc def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... @@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: ... -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... +multiply: BinaryUfunc nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., posinf: ArrayLike | None = ..., diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 16eb9321c822..e2450433ecb7 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -14,6 +14,7 @@ """Tests for jax.numpy.ufunc and its methods.""" +import itertools from functools import partial from absl.testing import absltest @@ -22,7 +23,6 @@ import jax import jax.numpy as jnp from jax._src import test_util as jtu -from jax._src.numpy.ufunc_api import get_if_single_primitive jax.config.parse_flags_with_absl() @@ -54,19 +54,22 @@ def scalar_sub(x, y): {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, ] -FASTPATH_FUNCS = [ - {'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0, - 'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p}, - {'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1, - 'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p}, -] +def _jnp_ufunc_props(name): + jnp_func = getattr(jnp, name) + assert isinstance(jnp_func, jnp.ufunc) + np_func = getattr(np, name) + dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types] + return [dict(name=name, dtype=dtype) for dtype in dtypes] + -NON_FASTPATH_FUNCS = [ - {'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0}, - {'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1}, - {'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1}, +JAX_NUMPY_UFUNCS = [ + name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc) ] +JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS +)) + broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] @@ -80,23 +83,40 @@ def wrapped(*args, **kwargs): class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties(self, func, nin, nout, identity): + def test_frompyfunc_properties(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun.identity, identity) self.assertEqual(jnp_fun.nin, nin) self.assertEqual(jnp_fun.nout, nout) self.assertEqual(jnp_fun.nargs, nin) + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties(self, name): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + self.assertEqual(jnp_fun.identity, np_fun.identity) + self.assertEqual(jnp_fun.nin, np_fun.nin) + self.assertEqual(jnp_fun.nout, np_fun.nout) + self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out` + @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties_readonly(self, func, nin, nout, identity): + def test_frompyfunc_properties_readonly(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) - for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: + getattr(jnp_fun, attr) # no error on attribute access. + with self.assertRaises(AttributeError): + setattr(jnp_fun, attr, None) # error when trying to mutate. + + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties_readonly(self, name): + jnp_fun = getattr(jnp, name) + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: getattr(jnp_fun, attr) # no error on attribute access. with self.assertRaises(AttributeError): setattr(jnp_fun, attr, None) # error when trying to mutate. @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_hash(self, func, nin, nout, identity): + def test_frompyfunc_hash(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun, jnp_fun_2) @@ -113,7 +133,7 @@ def test_ufunc_hash(self, func, nin, nout, identity): dtype=jtu.dtypes.floating, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity)) @@ -123,13 +143,28 @@ def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( SCALAR_FUNCS, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, dtype=jtu.dtypes.floating, ) - def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer @@ -141,6 +176,23 @@ def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + self._CompileAndCheck(jnp_fun.outer, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -148,7 +200,7 @@ def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) @@ -160,6 +212,26 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_ufunc_reduce(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis) + np_fun_reduce = partial(np_fun.reduce, axis=axis) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -167,7 +239,7 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") @@ -194,42 +266,28 @@ def np_fun(arr, where): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, - ) - def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del accumulator # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer) - - @jtu.sample_product( - NON_FASTPATH_FUNCS, - [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in [None, *range(-len(shape), len(shape))]], ) - def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - - _ = func(0, 0) # function should not error. + def test_ufunc_reduce_where(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if jnp_fun.identity is None: + self.skipTest("reduce with where requires identity") - reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertIsNone(get_if_single_primitive(reduce_fun, *args)) + jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where) + np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where) - accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertIsNone(get_if_single_primitive(accum_fun, *args)) + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -238,7 +296,7 @@ def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) - def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) @@ -251,20 +309,25 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in range(-len(shape), len(shape))] ) - def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del reducer # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + def test_ufunc_accumulate(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator) + args_maker = lambda: [rng(shape, dtype)] + + jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis) + np_fun_accumulate = partial(np_fun.accumulate, axis=axis) + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) + self._CompileAndCheck(jnp_fun_accumulate, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -272,7 +335,7 @@ def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype idx_shape=[(), (2,)], dtype=jtu.dtypes.floating, ) - def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype): + def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False) @@ -288,7 +351,31 @@ def np_fun(x, idx, y): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def test_at_broadcasting(self): + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx, y): + x_copy = x.copy() + np_fun.at(x_copy, idx, y) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + def test_frompyfunc_at_broadcasting(self): # Regression test for https://github.com/google/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] @@ -309,7 +396,7 @@ def np_fun(x, idx, y): idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating, ) - def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): + def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis) @@ -322,6 +409,29 @@ def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [*range(-len(shape), len(shape))]], + idx_shape=[(0,), (3,), (5,)], + ) + def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if name in ['add', 'multiply'] and dtype == bool: + # TODO(jakevdp): figure out how to fix thest cases. + self.skipTest(f"known failure for {name}.reduceat with {dtype=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun.reduceat, args_maker) + self._CompileAndCheck(jnp_fun.reduceat, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())