From ffa53b5f050034ae5e069782f702d621d7b45f20 Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 09:41:40 -0700 Subject: [PATCH 1/3] fixes cache miss in abstract_eval_shape for bcoo dot general --- jax/experimental/sparse/bcoo.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 4cbe52383751..20917c3f7152 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -738,12 +738,11 @@ def result(out_array, lhs_data, lhs_indices, rhs): @bcoo_dot_general_p.def_abstract_eval def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, preferred_element_type, lhs_spinfo: SparseInfo): - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs.shape, rhs.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) (lhs_contracting, _), (lhs_batch, _) = dimension_numbers n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape) From 14c719e810f58a721a1ba3f626439767f7a965fb Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 13:33:42 -0700 Subject: [PATCH 2/3] fixes cache miss for eval shape in BCOO related functions --- jax/experimental/sparse/bcoo.py | 37 +++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 20917c3f7152..a0a4df4d898f 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1186,12 +1186,11 @@ def 
_bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic dimension_numbers, preferred_element_type): lhs_shape = lhs_spinfo.shape rhs_shape = rhs_spinfo.shape - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape) rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape) @@ -1772,9 +1771,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: raise ValueError("bcoo_concatenate: expected operands to be a sequence of BCOO arrays. " f"Got {operands}") # Validate inputs using lax.concatenate abstract evaluation. - out_aval = jax.eval_shape( - functools.partial(lax.concatenate, dimension=dimension), - [core.ShapedArray(op.shape, op.dtype) for op in operands]) + out_aval = jax.jit(lax.concatenate, static_argnames=("dimension",)).eval_shape( + [core.ShapedArray(op.shape, op.dtype) for op in operands], + dimension=dimension) if len({op.n_dense for op in operands}) > 1: raise ValueError("bcoo_concatenate requires inputs to have matching nse dimensions.") @@ -1890,8 +1889,9 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in def bcoo_rev(operand, dimensions): """Sparse implementation of {func}`jax.lax.rev`""" # Check validity of dimensions via original implementation. 
- _ = jax.eval_shape(partial(lax.rev, dimensions=dimensions), - jax.ShapeDtypeStruct(operand.shape, operand.dtype)) + _ = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + dimensions=dimensions) batch_dims = [d for d in dimensions if d < operand.n_batch] sparse_dims = [d for d in dimensions if operand.n_batch <= d < operand.n_batch + operand.n_sparse] dense_dims = [d for d in dimensions if d >= operand.n_batch + operand.n_sparse] @@ -2035,8 +2035,9 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq out: BCOO array containing the slice. """ # Use abstract eval to validate inputs. - jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes), - jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices) + jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape( + jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices, + slice_sizes=slice_sizes) if not isinstance(mat, BCOO): raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = tuple(jnp.asarray(i) for i in start_indices) @@ -2302,9 +2303,13 @@ def bcoo_gather(operand: BCOO, start_indices: Array, mode=mode, fill_value=fill_value) # Abstract eval lax.gather to validate arguments & determine output shape. 
- out_aval = jax.eval_shape(partial(lax.gather, **kwds), - jax.ShapeDtypeStruct(operand.shape, operand.dtype), - jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype)) + static_argnames = ("dimension_numbers", "slice_sizes", "unique_indices", + "indices_are_sorted", "mode", "fill_value",) + out_aval = jax.jit(lax.gather, static_argnames=static_argnames).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype), + **kwds) + offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map From 7087d0cf081b29d5a1938fe354289dba051e1e41 Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 16:45:34 -0700 Subject: [PATCH 3/3] fixes cache miss and addresses static arg in BCOO dynamic slice --- jax/experimental/sparse/bcoo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index a0a4df4d898f..9eafa0db0fc2 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -2034,6 +2034,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq Returns: out: BCOO array containing the slice. """ + slice_sizes = tuple(operator.index(i) for i in slice_sizes) # Use abstract eval to validate inputs. 
jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape( jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices, @@ -2043,7 +2044,6 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq start_indices = tuple(jnp.asarray(i) for i in start_indices) assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices) assert all(i.shape == () for i in start_indices) - slice_sizes = tuple(operator.index(i) for i in slice_sizes) if len(start_indices) != len(slice_sizes) != mat.ndim: raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}") if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)):