From 1cba0970d83025574ef991d0d60a38ca8060046d Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 28 Aug 2024 10:13:58 -0700
Subject: [PATCH] refactor lax.loops to avoid importing from jax.numpy

---
 jax/_src/lax/control_flow/loops.py  |  4 +-
 jax/_src/lax/other.py               | 66 +++++++++++++++++++++++++----
 jax/_src/lax/windowed_reductions.py |  2 +-
 jax/_src/numpy/ufuncs.py            | 45 ++++++++++----------
 4 files changed, 82 insertions(+), 35 deletions(-)

diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index f7f09424a9e8..5084ec43c2fa 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -50,9 +50,9 @@
     _abstractify, _avals_short, _initial_style_jaxpr,
     _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
     _typecheck_param)
+from jax._src.lax.other import logaddexp
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.numpy.ufuncs import logaddexp
 from jax._src.state import discharge as state_discharge
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import equality_errors
@@ -2170,7 +2170,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype,
   new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
   return (new_keys, bits), (0, 0)
 
-batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule
+batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule  # type: ignore[has-type]
 
 
 ### associative_scan
diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py
index 7bdfabb92df8..45f9167ab807 100644
--- a/jax/_src/lax/other.py
+++ b/jax/_src/lax/other.py
@@ -19,9 +19,11 @@
 from typing import Any
 
 import jax
-from jax._src.numpy import lax_numpy as jnp
+from jax._src import dtypes
 from jax._src.lax import lax
 from jax._src.lax import convolution
+from jax._src import util
+import numpy as np
 
 DType = Any
 
@@ -88,7 +90,7 @@ def conv_general_dilated_patches(
       (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
 
   """
-  lhs_array = jnp.asarray(lhs)
+  lhs_array = lax.asarray(lhs)
   filter_shape = tuple(filter_shape)
   dimension_numbers = convolution.conv_dimension_numbers(
       lhs_array.shape, (1, 1) + filter_shape, dimension_numbers)
@@ -99,11 +101,10 @@ def conv_general_dilated_patches(
   n_channels = lhs_array.shape[lhs_spec[1]]
 
   # Move separate `lhs` spatial locations into separate `rhs` channels.
-  rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2)
-
-  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
-  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
-  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
+  rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size), offset=0)
+  rhs = lax.broadcast_in_dim(rhs, (n_channels, spatial_size, spatial_size), (1, 2))
+  rhs = lax.reshape(rhs, (n_channels * spatial_size, 1, *filter_shape))
+  rhs = util.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
 
   out = convolution.conv_general_dilated(
       lhs=lhs_array,
@@ -200,7 +201,7 @@ def conv_general_dilated_local(
     If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
     (for a 2D convolution).
   """
-  lhs_array = jnp.asarray(lhs)
+  lhs_array = lax.asarray(lhs)
 
   c_precision = lax.canonicalize_precision(precision)
   lhs_precision = (
@@ -234,5 +235,52 @@ def conv_general_dilated_local(
 
   dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
   out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
-  out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
+  out = util.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
   return out
+
+
+def _wrap_between(x, _a):
+  """Wraps `x` between `[-a, a]`."""
+  a = lax._const(x, _a)
+  two_a = lax._const(x, 2 * _a)
+  zero = lax._const(x, 0)
+  rem = lax.rem(lax.add(x, a), two_a)
+  rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
+  return lax.sub(rem, a)
+
+
+def _replace_inf(x: jax.Array) -> jax.Array:
+  re_x = lax.real(x) if dtypes.issubdtype(x.dtype, np.complexfloating) else x
+  inf = lax._const(re_x, float('inf'))
+  return lax.select(lax.eq(re_x, inf), lax._zeros(x), x)
+
+
+@jax.custom_jvp
+def logaddexp(x1: jax.typing.ArrayLike, x2: jax.typing.ArrayLike, /) -> jax.Array:
+  """Compute log(exp(x1) + exp(x2)) avoiding overflow."""
+  x1_arr = lax.asarray(x1)
+  x2_arr = lax.asarray(x2)
+  assert x1_arr.dtype == x2_arr.dtype
+
+  amax = lax.max(x1_arr, x2_arr)
+  if dtypes.isdtype(x1_arr.dtype, "real floating"):
+    delta = lax.sub(x1_arr, x2_arr)
+    return lax.select(lax._isnan(delta),
+                      lax.add(x1_arr, x2_arr),  # NaNs or infinities of the same sign.
+                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
+  elif dtypes.isdtype(x1_arr.dtype, "complex floating"):
+    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2)))
+    out = lax.add(amax, lax.log1p(lax.exp(delta)))
+    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
+  else:
+    raise ValueError(f"logaddexp requires floating-point or complex inputs; got {x1_arr.dtype}")
+
+
+@logaddexp.defjvp
+def _logaddexp_jvp(primals, tangents):
+  x1, x2 = primals
+  t1, t2 = tangents
+  primal_out = logaddexp(x1, x2)
+  tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
+                        lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
+  return primal_out, tangent_out
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index 5d6eddad0e4d..dd8e664a095a 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -30,9 +30,9 @@
 from jax._src.lax import convolution
 from jax._src.lax import lax
 from jax._src.lax import slicing
+from jax._src.lax.other import logaddexp
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.numpy.ufuncs import logaddexp
 from jax._src.typing import Array
 import numpy as np
 from jax._src.core import ClosedJaxpr
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index 36ce9f5135a3..dfeff38df0fe 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -29,6 +29,7 @@
 from jax._src.api import jit
 from jax._src.custom_derivatives import custom_jvp
 from jax._src.lax import lax
+from jax._src.lax import other as lax_other
 from jax._src.typing import Array, ArrayLike
 from jax._src.numpy.util import (
     check_arraylike, promote_args, promote_args_inexact,
@@ -857,21 +858,30 @@ def _pow_int_int(x1, x2):
   return acc
 
 
-@custom_jvp
-@implements(np.logaddexp, module='numpy')
 @jit
 def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow.
+
+  JAX implementation of :func:`numpy.logaddexp`
+
+  Args:
+    x1: input array
+    x2: input array
+
+  Returns:
+    array containing the result.
+
+  Examples:
+
+    >>> x1 = jnp.array([1, 2, 3])
+    >>> x2 = jnp.array([4, 5, 6])
+    >>> result1 = jnp.logaddexp(x1, x2)
+    >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
+    >>> print(jnp.allclose(result1, result2))
+    True
+  """
   x1, x2 = promote_args_inexact("logaddexp", x1, x2)
-  amax = lax.max(x1, x2)
-  if dtypes.issubdtype(x1.dtype, np.floating):
-    delta = lax.sub(x1, x2)
-    return lax.select(lax._isnan(delta),
-                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
-                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
-  else:
-    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
-    out = lax.add(amax, lax.log1p(lax.exp(delta)))
-    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
+  return lax_other.logaddexp(x1, x2)
 
 
 def _wrap_between(x, _a):
@@ -884,17 +894,6 @@ def _wrap_between(x, _a):
   return lax.sub(rem, a)
 
 
-@logaddexp.defjvp
-def _logaddexp_jvp(primals, tangents):
-  x1, x2 = primals
-  t1, t2 = tangents
-  x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
-  primal_out = logaddexp(x1, x2)
-  tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
-                        lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
-  return primal_out, tangent_out
-
-
 @custom_jvp
 @implements(np.logaddexp2, module='numpy')
 @jit
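
Note on the numerics being centralized here: both the new jax._src.lax.other.logaddexp
and the jnp.logaddexp wrapper above rely on the stable identity
logaddexp(x1, x2) = max(x1, x2) + log1p(exp(-|x1 - x2|)), which keeps the argument of
exp() non-positive. The sketch below is illustrative only: the helper names
naive_logaddexp and stable_logaddexp are not part of the patch, and it omits the
NaN/infinity handling and the complex branch that the patched implementation keeps.

    import jax.numpy as jnp

    def naive_logaddexp(x1, x2):
      # Direct evaluation: jnp.exp(1000.0) overflows to inf, so the result is inf.
      return jnp.log(jnp.exp(x1) + jnp.exp(x2))

    def stable_logaddexp(x1, x2):
      # Factor out the maximum so the remaining exponent is always <= 0.
      amax = jnp.maximum(x1, x2)
      return amax + jnp.log1p(jnp.exp(-jnp.abs(x1 - x2)))

    print(naive_logaddexp(1000.0, 1000.0))   # inf
    print(stable_logaddexp(1000.0, 1000.0))  # ~1000.6931 (= 1000 + log(2))
    print(jnp.logaddexp(1000.0, 1000.0))     # ~1000.6931

The custom JVP registered in lax/other.py uses the corresponding derivative
d logaddexp / d x_i = exp(x_i - logaddexp(x1, x2)); _replace_inf swaps positive
infinities for zeros before the subtraction so that inf - inf does not inject NaN
into the exponent.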