First pass at wrapping ufuncs by default
jakevdp committed Aug 28, 2024
1 parent 59d75b3 commit 196fff2
Showing 5 changed files with 281 additions and 162 deletions.
10 changes: 5 additions & 5 deletions jax/_src/numpy/reductions.py
@@ -30,7 +30,6 @@
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
@@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
a_shape = a.shape

if squash_nans:
a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
a = lax.sort(a, dimension=axis)
counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims)
shape_after_reduction = counts.shape
q = lax.expand_dims(
q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
@@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
index[axis] = high
high_value = a[tuple(index)]
else:
a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a)
a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a)
a = lax.sort(a, dimension=axis)
n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q))
q = lax.mul(q, n - 1)
@@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
Array([1.5, 3. , 4.5], dtype=float32)
"""
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
q, = promote_dtypes_inexact(q)
q = q / 100
if not isinstance(interpolation, DeprecatedArg):
deprecations.warn(
"jax-numpy-quantile-interpolation",
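
The nanpercentile change above replaces the call into the ufuncs module with an explicit dtype promotion followed by a plain division. A minimal sketch of the new normalization, written outside the diff and using only names already imported in reductions.py (values are illustrative, not output from this commit):

    from jax._src.numpy.util import promote_dtypes_inexact

    q = 50                          # integer percentile input
    q, = promote_dtypes_inexact(q)  # promote to an inexact (floating) dtype
    q = q / 100                     # scale the percentile to a quantile in [0, 1]
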
91 changes: 24 additions & 67 deletions jax/_src/numpy/ufunc_api.py
@@ -17,21 +17,19 @@
from __future__ import annotations

from collections.abc import Callable
from functools import partial
from functools import cached_property, partial
import math
import operator
from typing import Any

import jax
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
import jax._src.numpy.lax_numpy as jnp
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
from jax._src.numpy.util import implements, check_arraylike, promote_args, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis, set_module
from jax._src import pjit
import numpy as np


@@ -42,40 +40,6 @@
"""


def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None:
"""
If fun(*args) lowers to a single primitive with inputs and outputs matching
function inputs and outputs, return that primitive. Otherwise return None.
"""
try:
jaxpr = jax.make_jaxpr(fun)(*args)
except:
return None
while len(jaxpr.eqns) == 1:
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == pjit.pjit_p and
all(pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:
return jaxpr.eqns[0].primitive
return None


_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}


_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}


@set_module('jax.numpy')
class ufunc:
"""Functions that operate element-by-element on whole arrays.
@@ -86,25 +50,27 @@ def __init__(self, func: Callable[..., Any], /,
nin: int, nout: int, *,
name: str | None = None,
nargs: int | None = None,
identity: Any = None, update_doc=False):
identity: Any = None,
reduce: Callable[..., Any] | None = None,
accumulate: Callable[..., Any] | None = None):
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
# We want ufunc instances to work properly when marked as static,
# and for this reason it's important that their properties not be
# mutated. We prevent this by storing them in a dunder attribute,
# and accessing them via read-only properties.
if update_doc:
self.__doc__ = func.__doc__
self.__name__ = name or func.__name__
self.__static_props = {
'func': func,
'call': vectorize(func),
'nin': operator.index(nin),
'nout': operator.index(nout),
'nargs': operator.index(nargs or nin),
'identity': identity
'identity': identity,
'reduce': reduce,
'accumulate': accumulate,
}

_func = property(lambda self: self.__static_props['func'])
_call = property(lambda self: self.__static_props['call'])
_call = cached_property(lambda self: jax.jit(vectorize(self._func)))
nin = property(lambda self: self.__static_props['nin'])
nout = property(lambda self: self.__static_props['nout'])
nargs = property(lambda self: self.__static_props['nargs'])
@@ -124,14 +90,13 @@ def __eq__(self, other: Any) -> bool:
def __repr__(self) -> str:
return f"<jnp.ufunc '{self.__name__}'>"

def __call__(self, *args: ArrayLike,
out: None = None, where: None = None,
**kwargs: Any) -> Any:
@partial(jax.jit, static_argnames=['self'])
def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any:
if out is not None:
raise NotImplementedError(f"out argument of {self}")
if where is not None:
raise NotImplementedError(f"where argument of {self}")
return self._call(*args, **kwargs)
return self._call(*promote_args(self.__name__, *args))

@implements(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
@@ -154,14 +119,10 @@ def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
"so to use a where mask one has to specify 'initial'.")
if lax_internal._dtype(where) != bool:
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
if primitive is None:
reducer = self._reduce_via_scan
else:
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
reducer = self.__static_props['reduce'] or self._reduce_via_scan
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)

def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
assert self.nin == 2 and self.nout == 1
@@ -231,11 +192,7 @@ def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None
raise ValueError("accumulate only supported for functions returning a single value")
if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
if primitive is None:
accumulator = self._accumulate_via_scan
else:
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
accumulator = self.__static_props['accumulate'] or self._accumulate_via_scan
return accumulator(a, axis=axis, dtype=dtype)

def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0,
@@ -276,7 +233,7 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype)
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
indices = _eliminate_deprecated_list_indexing(indices)
indices = jnp._eliminate_deprecated_list_indexing(indices)
if not indices:
return a

@@ -314,7 +271,7 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None) -> Array:
check_arraylike(f"{self.__name__}.reduceat", a, indices)
a = lax_internal.asarray(a)
idx_tuple = _eliminate_deprecated_list_indexing(indices)
idx_tuple = jnp._eliminate_deprecated_list_indexing(indices)
assert len(idx_tuple) == 1
indices = idx_tuple[0]
if a.ndim == 0:
@@ -326,14 +283,14 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0,
if axis is None or isinstance(axis, (tuple, list)):
raise ValueError("reduceat requires a single integer axis.")
axis = canonicalize_axis(axis, a.ndim)
out = take(a, indices, axis=axis)
ind = jax.lax.expand_dims(append(indices, a.shape[axis]),
out = jnp.take(a, indices, axis=axis)
ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]),
list(np.delete(np.arange(out.ndim), axis)))
ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
def loop_body(i, out):
return _where((i > ind_start) & (i < ind_end),
self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
self._call(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)),
out)
return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)

@@ -363,4 +320,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
Returns:
wrapped : jax.numpy.ufunc wrapper of func.
"""
return ufunc(func, nin, nout, identity=identity, update_doc=True)
return ufunc(func, nin, nout, identity=identity)
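
With the primitive-detection path removed, frompyfunc now simply builds a ufunc whose reduce, accumulate, reduceat, and at methods fall back to the generic scan-based implementations unless explicit overrides are passed to the constructor. A hedged usage sketch of the wrapped API (expected values, not output captured from this commit):

    import jax.numpy as jnp

    mul = jnp.frompyfunc(lambda a, b: a * b, nin=2, nout=1, identity=1)
    x = jnp.array([1, 2, 3, 4])
    mul(x, x)           # elementwise call, vectorized and jit-compiled
    mul.reduce(x)       # 24, computed via _reduce_via_scan
    mul.accumulate(x)   # [1, 2, 6, 24], computed via _accumulate_via_scan
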
58 changes: 41 additions & 17 deletions jax/_src/numpy/ufuncs.py
@@ -30,11 +30,13 @@
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.lax import other as lax_other
from jax._src.typing import Array, ArrayLike
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, implements, check_no_float0s)
from jax._src.numpy.ufunc_api import ufunc
from jax._src.numpy import reductions

_lax_const = lax._const

@@ -299,30 +301,25 @@ def cbrt(x: ArrayLike, /) -> Array:
return lax.cbrt(*promote_args_inexact('cbrt', x))

@implements(np.add, module='numpy')
@partial(jit, inline=True)
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)

@implements(np.multiply, module='numpy')
@partial(jit, inline=True)
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)

@implements(np.bitwise_and, module='numpy')
@partial(jit, inline=True)
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_and(*promote_args("bitwise_and", x, y))

@implements(np.bitwise_or, module='numpy')
@partial(jit, inline=True)
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_or(*promote_args("bitwise_or", x, y))

@implements(np.bitwise_xor, module='numpy')
@partial(jit, inline=True)
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))

@implements(np.left_shift, module='numpy')
@@ -377,18 +374,15 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:

# Logical ops
@implements(np.logical_and, module='numpy')
@partial(jit, inline=True)
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))

@implements(np.logical_or, module='numpy')
@partial(jit, inline=True)
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))

@implements(np.logical_xor, module='numpy')
@partial(jit, inline=True)
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))

@implements(np.logical_not, module='numpy')
@@ -1281,3 +1275,33 @@ def _sinc_maclaurin(k, x):
def _sinc_maclaurin_jvp(k, primals, tangents):
(x,), (t,) = primals, tangents
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t


def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)


def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None):
if initial is not None:
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
return result if dtype is None else result.astype(dtype)


# Generate ufunc interfaces for several common binary functions.
# We start with binary ufuncs that have well-defined identities.
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, reduce=reductions.sum, accumulate=reductions.cumsum)
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, reduce=reductions.prod, accumulate=reductions.cumprod)
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1)
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0)
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0)
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, reduce=_logical_and_reduce)
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, reduce=_logical_or_reduce)
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False)
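
Assuming these module-level ufunc instances are re-exported as jnp.add, jnp.multiply, jnp.logical_or, and so on (as the public namespace already does for the plain functions), the intended effect of this first pass is roughly:

    import jax.numpy as jnp

    x = jnp.arange(5.0)
    jnp.add(x, x)                 # still the elementwise add, now via a ufunc wrapper
    jnp.add.reduce(x)             # dispatches to reductions.sum rather than a scan
    jnp.add.accumulate(x)         # dispatches to reductions.cumsum
    jnp.logical_or.reduce(x > 1)  # handled by the _logical_or_reduce helper above

ufuncs without a registered reduce/accumulate override (for example bitwise_xor here) fall back to the scan-based paths in ufunc_api.py.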
