diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 931587c0cfcf..171954bc86c5 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -646,7 +646,8 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training) # reshape to original shape output = outputs[0] @@ -698,6 +699,7 @@ def _dot_product_attention_bwd_batcher( fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, ) # reshape to original shape