Skip to content

Commit

Permalink
refactor lax.loops to avoid importing from jax.numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 28, 2024
1 parent 26619e2 commit 1cba097
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 35 deletions.
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
_abstractify, _avals_short, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)
from jax._src.lax.other import logaddexp
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.state import discharge as state_discharge
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import equality_errors
Expand Down Expand Up @@ -2170,7 +2170,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype,
new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
return (new_keys, bits), (0, 0)

batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore[has-type]

### associative_scan

Expand Down
66 changes: 57 additions & 9 deletions jax/_src/lax/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from typing import Any

import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src import util
import numpy as np

DType = Any

Expand Down Expand Up @@ -88,7 +90,7 @@ def conv_general_dilated_patches(
(`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
"""
lhs_array = jnp.asarray(lhs)
lhs_array = lax.asarray(lhs)
filter_shape = tuple(filter_shape)
dimension_numbers = convolution.conv_dimension_numbers(
lhs_array.shape, (1, 1) + filter_shape, dimension_numbers)
Expand All @@ -99,11 +101,10 @@ def conv_general_dilated_patches(
n_channels = lhs_array.shape[lhs_spec[1]]

# Move separate `lhs` spatial locations into separate `rhs` channels.
rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2)

rhs = rhs.reshape((spatial_size, 1) + filter_shape)
rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size), offset=0)
rhs = lax.broadcast_in_dim(rhs, (n_channels, spatial_size, spatial_size), (1, 2))
rhs = lax.reshape(rhs, (n_channels * spatial_size, 1, *filter_shape))
rhs = util.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

out = convolution.conv_general_dilated(
lhs=lhs_array,
Expand Down Expand Up @@ -200,7 +201,7 @@ def conv_general_dilated_local(
If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
(for a 2D convolution).
"""
lhs_array = jnp.asarray(lhs)
lhs_array = lax.asarray(lhs)

c_precision = lax.canonicalize_precision(precision)
lhs_precision = (
Expand Down Expand Up @@ -234,5 +235,52 @@ def conv_general_dilated_local(

dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
out = util.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
return out


def _wrap_between(x, _a):
"""Wraps `x` between `[-a, a]`."""
a = lax._const(x, _a)
two_a = lax._const(x, 2 * _a)
zero = lax._const(x, 0)
rem = lax.rem(lax.add(x, a), two_a)
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
return lax.sub(rem, a)


def _replace_inf(x: jax.Array) -> jax.Array:
re_x = lax.real(x) if dtypes.issubdtype(x.dtype, np.complexfloating) else x
inf = lax._const(re_x, float('inf'))
return lax.select(lax.eq(re_x, inf), lax._zeros(x), x)


@jax.custom_jvp
def logaddexp(x1: jax.typing.ArrayLike, x2: jax.typing.ArrayLike, /) -> jax.Array:
"""Compute log(exp(x1) + exp(x2)) avoiding overflow."""
x1_arr = lax.asarray(x1)
x2_arr = lax.asarray(x2)
assert x1_arr.dtype == x2_arr.dtype

amax = lax.max(x1_arr, x2_arr)
if dtypes.isdtype(x1_arr.dtype, "real floating"):
delta = lax.sub(x1_arr, x2_arr)
return lax.select(lax._isnan(delta),
lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
elif dtypes.isdtype(x1_arr.dtype, "complex floating"):
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2)))
out = lax.add(amax, lax.log1p(lax.exp(delta)))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
else:
raise ValueError(f"logaddexp requires floating-point or complex inputs; got {x1_arr.dtype}")


@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
primal_out = logaddexp(x1, x2)
tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out
2 changes: 1 addition & 1 deletion jax/_src/lax/windowed_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from jax._src.lax import convolution
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax.other import logaddexp
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.typing import Array
import numpy as np
from jax._src.core import ClosedJaxpr
Expand Down
45 changes: 22 additions & 23 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
from jax._src.lax import lax
from jax._src.lax import other as lax_other
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
check_arraylike, promote_args, promote_args_inexact,
Expand Down Expand Up @@ -857,21 +858,30 @@ def _pow_int_int(x1, x2):
return acc


@custom_jvp
@implements(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Compute ``log(exp(x1) + exp(x2))`` avoiding overflow.
JAX implementation of :func:`numpy.logaddexp`
Args:
x1: input array
x2: input array
Returns:
array containing the result.
Examples:
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True
"""
x1, x2 = promote_args_inexact("logaddexp", x1, x2)
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(lax._isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.log1p(lax.exp(delta)))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
return lax_other.logaddexp(x1, x2)


def _wrap_between(x, _a):
Expand All @@ -884,17 +894,6 @@ def _wrap_between(x, _a):
return lax.sub(rem, a)


@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
primal_out = logaddexp(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out


@custom_jvp
@implements(np.logaddexp2, module='numpy')
@jit
Expand Down

0 comments on commit 1cba097

Please sign in to comment.