Skip to content

Commit

Permalink
Avoid unnecessary fori_loop when calculating the block indices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621580324
  • Loading branch information
voutcn authored and jax authors committed Apr 3, 2024
1 parent 85cb169 commit 9bb3f79
Showing 1 changed file with 28 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,35 @@ def paged_flash_attention_kernel(
b = b * b_step + b_start
length = lengths_ref[b]

def advance_to_next_non_zero_length(b):
return lax.fori_loop(
lax.div(b, b_step),
lax.div(batch_size, b_step),
lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
b,
)

def compute_block_indices(b, h, i):
length = lengths_ref[b]
not_done = i * bk < length
i_next = jnp.where(not_done, i, 0)
h_next = jnp.where(not_done, h, h + h_step)
is_last_head = h_next >= num_kv_heads
h_next = jnp.where(is_last_head, h_start, h_next)
b_next = jnp.where(
is_last_head, advance_to_next_non_zero_length(b + b_step), b
)
return b_next, h_next, i_next

def advance_b():
next_b = b + b_step

def advance_to_next_non_zero_length():
next_next_b = next_b + b_step
return lax.fori_loop(
lax.div(next_next_b, b_step),
lax.div(batch_size, b_step),
lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
next_next_b,
)

return (
lax.cond(
jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0),
advance_to_next_non_zero_length,
lambda: next_b,
),
h_start,
0,
)

def advance_h():
next_h = h + h_step
return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b)

return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h)

def create_kv_async_copy_descriptors(b, h, i, buffer_index):
page_offset = b * pages_per_sequence + i * pages_per_compute_block
Expand Down

0 comments on commit 9bb3f79

Please sign in to comment.