Skip to content

Commit

Permalink
[Pallas] Add support for pytrees in scalar prefetch
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570453699
  • Loading branch information
sharadmv authored and jax authors committed Oct 3, 2023
1 parent ee8af09 commit 24ad445
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
19 changes: 11 additions & 8 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
from typing import Any, Callable, Iterator

from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import state
Expand Down Expand Up @@ -136,20 +137,20 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:

def _convert_block_spec_to_block_mapping(
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
aval: jax_core.ShapedArray,
aval: jax_core.ShapedArray, in_tree: Any,
) -> BlockSpec | None:
if block_spec is no_block_spec:
return None
if block_spec.index_map is None:
compute_index = lambda *args: (0,) * len(aval.shape)
compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
block_shape = aval.shape
else:
compute_index = block_spec.compute_index
block_shape = block_spec.block_shape
block_shape = tuple(
mapped if s is None else s for s in block_shape)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(compute_index), in_avals)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts),
block_spec.memory_space)

Expand Down Expand Up @@ -249,12 +250,14 @@ def get_grid_mapping(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs,
in_ref_avals)
partial(_convert_block_spec_to_block_mapping, grid_avals,
in_tree=grid_tree), in_specs, in_ref_avals)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs,
out_ref_avals)
partial(_convert_block_spec_to_block_mapping, grid_avals,
in_tree=grid_tree), out_specs, out_ref_avals)
grid_mapping = GridMapping(
self.grid, (*in_block_mappings, *out_block_mappings), (),
num_index_operands=0)
Expand Down
11 changes: 8 additions & 3 deletions jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_specs_tree: Any
out_specs_tree: Any


def __init__(
self,
num_scalar_prefetch: int,
Expand Down Expand Up @@ -160,12 +159,18 @@ def get_grid_mapping(
state.shaped_array_ref(aval.shape, aval.dtype)
for aval in flat_scalar_avals]
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
index_map_in_tree = tree_util.tree_structure(
((*grid_avals, *scalar_avals), {})
)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), in_specs, in_ref_avals)
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree), in_specs, in_ref_avals)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), out_specs, out_ref_avals)
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree), out_specs, out_ref_avals)
grid_mapping = GridMapping(
grid=self.grid,
block_mappings=(*in_block_mappings, *out_block_mappings),
Expand Down

0 comments on commit 24ad445

Please sign in to comment.