Skip to content

Commit

Permalink
Merge pull request #23238 from quattro:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669373295
  • Loading branch information
jax authors committed Aug 30, 2024
2 parents 8ccc439 + 7087d0c commit f8a4662
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,12 +738,11 @@ def result(out_array, lhs_data, lhs_indices, rhs):
@bcoo_dot_general_p.def_abstract_eval
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
out_aval = jax.eval_shape(
partial(lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype))
out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape(
jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs.shape, rhs.dtype),
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)

(lhs_contracting, _), (lhs_batch, _) = dimension_numbers
n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape)
Expand Down Expand Up @@ -1187,12 +1186,11 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic
dimension_numbers, preferred_element_type):
lhs_shape = lhs_spinfo.shape
rhs_shape = rhs_spinfo.shape
out_aval = jax.eval_shape(
partial(lax.dot_general,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type),
jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype))
out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape(
jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype),
jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype),
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type)

lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape)
rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape)
Expand Down Expand Up @@ -1773,9 +1771,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO:
raise ValueError("bcoo_concatenate: expected operands to be a sequence of BCOO arrays. "
f"Got {operands}")
# Validate inputs using lax.concatenate abstract evaluation.
out_aval = jax.eval_shape(
functools.partial(lax.concatenate, dimension=dimension),
[core.ShapedArray(op.shape, op.dtype) for op in operands])
out_aval = jax.jit(lax.concatenate, static_argnames=("dimension",)).eval_shape(
[core.ShapedArray(op.shape, op.dtype) for op in operands],
dimension=dimension)
if len({op.n_dense for op in operands}) > 1:
raise ValueError("bcoo_concatenate requires inputs to have matching nse dimensions.")

Expand Down Expand Up @@ -1891,8 +1889,9 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in
def bcoo_rev(operand, dimensions):
"""Sparse implementation of {func}`jax.lax.rev`"""
# Check validity of dimensions via original implementation.
_ = jax.eval_shape(partial(lax.rev, dimensions=dimensions),
jax.ShapeDtypeStruct(operand.shape, operand.dtype))
_ = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape(
jax.ShapeDtypeStruct(operand.shape, operand.dtype),
dimensions=dimensions)
batch_dims = [d for d in dimensions if d < operand.n_batch]
sparse_dims = [d for d in dimensions if operand.n_batch <= d < operand.n_batch + operand.n_sparse]
dense_dims = [d for d in dimensions if d >= operand.n_batch + operand.n_sparse]
Expand Down Expand Up @@ -2035,15 +2034,16 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
Returns:
out: BCOO array containing the slice.
"""
slice_sizes = tuple(operator.index(i) for i in slice_sizes)
# Use abstract eval to validate inputs.
jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes),
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices)
jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape(
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices,
slice_sizes=slice_sizes)
if not isinstance(mat, BCOO):
raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}")
start_indices = tuple(jnp.asarray(i) for i in start_indices)
assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices)
assert all(i.shape == () for i in start_indices)
slice_sizes = tuple(operator.index(i) for i in slice_sizes)
if len(start_indices) != len(slice_sizes) != mat.ndim:
raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}")
if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)):
Expand Down Expand Up @@ -2303,9 +2303,13 @@ def bcoo_gather(operand: BCOO, start_indices: Array,
mode=mode, fill_value=fill_value)

# Abstract eval lax.gather to validate arguments & determine output shape.
out_aval = jax.eval_shape(partial(lax.gather, **kwds),
jax.ShapeDtypeStruct(operand.shape, operand.dtype),
jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype))
static_argnames = ("dimension_numbers", "slice_sizes", "unique_indices",
"indices_are_sorted", "mode", "fill_value",)
out_aval = jax.jit(lax.gather, static_argnames=static_argnames).eval_shape(
jax.ShapeDtypeStruct(operand.shape, operand.dtype),
jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype),
**kwds)

offset_dims = dimension_numbers.offset_dims
collapsed_slice_dims = dimension_numbers.collapsed_slice_dims
start_index_map = dimension_numbers.start_index_map
Expand Down

0 comments on commit f8a4662

Please sign in to comment.