Merge pull request #23247 from kaixih:sliding_window_attn
PiperOrigin-RevId: 676079831
Google-ML-Automation committed Sep 18, 2024
2 parents cd04d0f + 541b3a3 commit b164d67
Showing 2 changed files with 62 additions and 15 deletions.
58 changes: 49 additions & 9 deletions jax/_src/nn/functions.py
@@ -785,6 +785,14 @@ def _get_causal_mask(T, S):
mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return mask[None, None, :, :]

def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
query_pos = jnp.array(range(T))
key_pos = jnp.array(range(S))
left_window, right_window = local_window_size
left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
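
For illustration, a minimal sketch of the pattern `_get_window_mask` produces, using arbitrary values T = S = 6 and local_window_size = (2, 1), so each query position i attends to key positions i - 2 through i + 1; the leading (1, 1, T, S) broadcast dimensions are omitted here:

import jax.numpy as jnp

T = S = 6
left_window, right_window = 2, 1
query_pos = jnp.arange(T)[:, None]
key_pos = jnp.arange(S)[None, :]
left_mask = query_pos <= key_pos + left_window    # key at most left_window behind the query
right_mask = query_pos >= key_pos - right_window  # key at most right_window ahead of the query
window_mask = jnp.logical_and(right_mask, left_mask)
print(window_mask.astype(jnp.int32))
# [[1 1 0 0 0 0]
#  [1 1 1 0 0 0]
#  [1 1 1 1 0 0]
#  [0 1 1 1 1 0]
#  [0 0 1 1 1 1]
#  [0 0 0 1 1 1]]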

def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
q_mask = True
kv_mask = True
@@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen):
mask = q_indices < q_seqlen[:, None]
return mask[:, :, None, None]

def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
local_window_size):
if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
return logits

@@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
mask = _get_causal_mask(T, S)
combined_mask = jnp.logical_and(combined_mask, mask)

if local_window_size is not None:
mask = _get_window_mask(T, S, local_window_size)
combined_mask = jnp.logical_and(combined_mask, mask)

if q_seqlen is not None or kv_seqlen is not None:
mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
combined_mask = jnp.logical_and(combined_mask, mask)
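
For illustration, when a causal mask and a local window are both requested they are simply AND-ed into `combined_mask`; a small sketch with arbitrary values T = S = 5 and a (2, 0) window:

import jax.numpy as jnp

T = S = 5
causal = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))  # as in _get_causal_mask
q_pos = jnp.arange(T)[:, None]
k_pos = jnp.arange(S)[None, :]
window = (q_pos <= k_pos + 2) & (q_pos >= k_pos - 0)  # (left, right) = (2, 0)
combined = jnp.logical_and(causal, window)
print(combined.astype(jnp.int32))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 1 1 1 0]
#  [0 0 1 1 1]]
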
@@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
return padded_logits

def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale, q_seqlen, kv_seqlen):
scale, q_seqlen, kv_seqlen, local_window_size):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
@@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
if bias is not None:
logits = (logits + bias).astype(logits.dtype)

padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
local_window_size)

# Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(jnp.float32)
@@ -857,7 +871,8 @@ def _dot_product_attention_xla(
is_causal: bool,
scale: float,
q_seqlen: Array | None,
kv_seqlen: Array | None):
kv_seqlen: Array | None,
local_window_size: tuple[int, int] | None):

B, T, N, H = query.shape
_, S, K, _ = key.shape
@@ -875,11 +890,13 @@ def _reshape_to_grouped(t):
return t
bias = _reshape_to_grouped(bias)
mask = _reshape_to_grouped(mask)
vmapped_fn = jax.vmap(_dot_product_attention_core,
in_axes=(3, None, None, 2, 2, None, None, None, None),
out_axes=3)
vmapped_fn = jax.vmap(
_dot_product_attention_core,
in_axes=(3, None, None, 2, 2, None, None, None, None, None),
out_axes=3,
)
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
q_seqlen, kv_seqlen)
q_seqlen, kv_seqlen, local_window_size)
encoded = jnp.reshape(encoded, (B, T, N, H))
return encoded

@@ -894,6 +911,7 @@ def dot_product_attention(
is_causal: bool = False,
query_seq_lengths: ArrayLike | None = None,
key_value_seq_lengths: ArrayLike | None = None,
local_window_size: int | tuple[int, int] | None = None,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
r"""Scaled dot product attention function.
@@ -943,6 +961,12 @@
:code:`(B)`
key_value_seq_lengths: `int32` array of sequence lengths for key and value;
shape :code:`(B)`
local_window_size: Window sizes to make self-attention attend to each
token's local window. If set, this specifies the (left_window_size,
right_window_size) for each token. E.g., if local_window_size == (3, 2)
and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as
a symmetric window (window_size, window_size).
implementation: A string to control which implementation backend to use.
Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
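
For illustration, a hedged usage sketch of the new argument through `jax.nn.dot_product_attention`; the shapes, seed, and window sizes below are arbitrary:

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 16, 16, 4, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, T, N, H))
key = jax.random.normal(kk, (B, S, N, H))
value = jax.random.normal(kv, (B, S, N, H))

# Each query token attends to 3 tokens to its left, itself, and 2 to its right.
out = jax.nn.dot_product_attention(query, key, value, local_window_size=(3, 2))

# A single int is interpreted as a symmetric (3, 3) window.
out_symmetric = jax.nn.dot_product_attention(query, key, value, local_window_size=3)
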
@@ -969,6 +993,8 @@ def _ensure_4d(t):
query_seq_lengths = jnp.asarray(query_seq_lengths)
if key_value_seq_lengths is not None:
key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
if isinstance(local_window_size, int):
local_window_size = (local_window_size, local_window_size)

def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
dtype: DType | None, name: str) -> None:
@@ -1003,6 +1029,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case 'cudnn':
use_padding = (
@@ -1022,9 +1049,21 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
mask_type = MaskType.CAUSAL
elif use_padding:
mask_type = MaskType.PADDING
# cuDNN supports only the left window, with an exclusive boundary, when the
# causal mask is enabled.
sliding_window = None
if local_window_size is not None:
l_window, r_window = local_window_size
if r_window == 0 or mask_type == MaskType.CAUSAL:
sliding_window = l_window + 1
else:
raise ValueError(f"cuDNN doesn't support right window: {r_window} "
"when causal mask is not used.")

out = cudnn_dot_product_attention(
query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
key_value_seq_lengths, scale=scale_val, mask_type=mask_type
key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
sliding_window_length=sliding_window,
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
@@ -1033,6 +1072,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
scale=scale_val, q_seqlen=query_seq_lengths,
kv_seqlen=key_value_seq_lengths,
local_window_size=local_window_size,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
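
For the `cudnn` branch, the (left, right) window is folded into cuDNN's single `sliding_window_length`, which counts the query token itself and uses an exclusive left boundary. A minimal sketch of that conversion under the same assumptions as the diff (the real code checks the resolved `mask_type` rather than `is_causal` directly):

def to_cudnn_sliding_window(local_window_size, is_causal):
  # Illustrative restatement of the cudnn-branch conversion above; the diff
  # checks mask_type == MaskType.CAUSAL, for which is_causal stands in here.
  if local_window_size is None:
    return None
  left_window, right_window = local_window_size
  if right_window == 0 or is_causal:
    # left_window tokens before the query plus the query token itself.
    return left_window + 1
  raise ValueError(f"cuDNN doesn't support right window: {right_window} "
                   "when causal mask is not used.")

assert to_cudnn_sliding_window((3, 0), is_causal=False) == 4
assert to_cudnn_sliding_window((3, 2), is_causal=True) == 4
assert to_cudnn_sliding_window(None, is_causal=False) is None
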
19 changes: 13 additions & 6 deletions tests/nn_test.py
@@ -38,11 +38,11 @@

config.parse_flags_with_absl()

def _is_required_cudnn_version_satisfied():
def _is_required_cudnn_version_satisfied(min_cudnn_version):
return (
jtu.is_cuda_compute_capability_at_least("8.0") and
cuda_versions is not None and
cuda_versions.cudnn_get_version() >= 8904
cuda_versions.cudnn_get_version() >= min_cudnn_version
)

def _check_cudnn_backend(fn, *args, **kwargs):
@@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -102,13 +102,15 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):

@parameterized.product(
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
('custom', 'padding'), ('bias', 'causal')],
('custom', 'padding'), ('bias', 'causal'),
('causal', 'sliding_window')],
)
def testDotProductAttentionMask(self, mask_mode):
if not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
if not _is_required_cudnn_version_satisfied(min_cudnn_version):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

dtype = jnp.bfloat16
B, S, T, N, H = 2, 128, 128, 4, 32
@@ -119,6 +121,7 @@ def testDotProductAttentionMask(self, mask_mode):
grad = random.normal(keys[3], (B, T, N, H), dtype)
bias, mask = None, None
q_seqlen, kv_seqlen = None, None
window_size = None

is_causal = 'causal' in mask_mode
if 'padding' in mask_mode:
@@ -130,6 +133,8 @@ def testDotProductAttentionMask(self, mask_mode):
mask = custom_mask[None, None, :, :]
if 'bias' in mask_mode:
bias = random.normal(keys[4], (1, N, T, S), dtype)
if 'sliding_window' in mask_mode:
window_size = (3, 2) if is_causal else (3, 0)

sdpa = nn.dot_product_attention
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
@@ -141,9 +146,11 @@ def testDotProductAttentionMask(self, mask_mode):
# Convert the kargs to positional args for the jax.vjp.
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
local_window_size=window_size,
)
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
local_window_size=window_size,
)
out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
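
The new 'sliding_window' mask mode compares the sliding-window path against the reference implementation. As a standalone illustration with arbitrary shapes, the `local_window_size` argument should agree with an equivalent explicit band mask passed via `mask`:

import jax
import jax.numpy as jnp

B, S, T, N, H = 1, 32, 32, 2, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, T, N, H))
key = jax.random.normal(kk, (B, S, N, H))
value = jax.random.normal(kv, (B, S, N, H))

left, right = 3, 2
q_pos = jnp.arange(T)[:, None]
k_pos = jnp.arange(S)[None, :]
band_mask = ((k_pos >= q_pos - left) & (k_pos <= q_pos + right))[None, None, :, :]

out_window = jax.nn.dot_product_attention(query, key, value,
                                          local_window_size=(left, right))
out_masked = jax.nn.dot_product_attention(query, key, value, mask=band_mask)
assert jnp.allclose(out_window, out_masked, atol=1e-6)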
