From 1966b5791f5a678d70a9fb35380ea0af8a47bfd8 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 Jul 2023 10:24:58 -0700 Subject: [PATCH] Fix vmap rules when num_index_operands > 0 PiperOrigin-RevId: 548730993 --- jax_triton/pallas/pallas_call.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index 15076e10..c1faa129 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -215,7 +215,11 @@ def _block_map_function(new_idx, *args): if dim is not batching.not_mapped: indices.insert(dim, new_idx) return tuple(indices) - idx_avals = [jax_core.ShapedArray((), jnp.int32)] * (len(grid) + 1) + i32_aval = jax_core.ShapedArray((), jnp.int32) + if block_mapping is None: + idx_avals = [i32_aval] * (len(grid) + 1) + else: + idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals] block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) shape = aval.shape if block_mapping is None else block_mapping.block_shape @@ -271,8 +275,11 @@ def _pallas_call_batching_rule(args, dims, *, all_dims = list(dims) + [0] * len(out_shapes) - batched_block_mappings = map(partial(_batch_block_mapping, grid_mapping.grid), - avals, all_dims, block_mappings) + num_index_operands = grid_mapping.num_index_operands + batched_block_mappings = map( + partial(_batch_block_mapping, grid_mapping.grid), + avals[num_index_operands:], all_dims[num_index_operands:], block_mappings) + batched_in_shapes = tuple( jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else tuple_insert(x.shape, dim, axis_size),