Skip to content

Commit

Permalink
Fixed rules where sliding_window_length was not forwarded
Browse files Browse the repository at this point in the history
This is follow up to #23284.

PiperOrigin-RevId: 670531634
  • Loading branch information
superbobry authored and jax authors committed Sep 3, 2024
1 parent 7b161fb commit ccabd21
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax/_src/cudnn/fused_attention_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,8 @@ def _dot_product_attention_fwd_batcher(
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout, is_training=is_training)
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)

# reshape to original shape
output = outputs[0]
Expand Down Expand Up @@ -698,6 +699,7 @@ def _dot_product_attention_bwd_batcher(
fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
)

# reshape to original shape
Expand Down

0 comments on commit ccabd21

Please sign in to comment.