Skip to content

Commit

Permalink
fixes cache miss and addresses static arg in BCOO dynamic slice
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Aug 26, 2024
1 parent 14c719e commit 7087d0c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
Returns:
out: BCOO array containing the slice.
"""
slice_sizes = tuple(operator.index(i) for i in slice_sizes)
# Use abstract eval to validate inputs.
jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape(
jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices,
Expand All @@ -2043,7 +2044,6 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
start_indices = tuple(jnp.asarray(i) for i in start_indices)
assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices)
assert all(i.shape == () for i in start_indices)
slice_sizes = tuple(operator.index(i) for i in slice_sizes)
if len(start_indices) != len(slice_sizes) != mat.ndim:
raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}")
if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)):
Expand Down

0 comments on commit 7087d0c

Please sign in to comment.