From 541b3a3f7565b0e3f826b388dd094d22b28efb54 Mon Sep 17 00:00:00 2001
From: kaixih
Date: Mon, 26 Aug 2024 17:32:38 +0000
Subject: [PATCH] Add local_window_size support to jax.nn.dot_product_attention

---
 jax/_src/nn/functions.py | 58 +++++++++++++++++++++++++++++++++-------
 tests/nn_test.py         | 19 ++++++++-----
 2 files changed, 62 insertions(+), 15 deletions(-)

diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index a5b5aaf31799..c1f4831e5ec0 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -785,6 +785,14 @@ def _get_causal_mask(T, S):
   mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
   return mask[None, None, :, :]
 
+def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
+  query_pos = jnp.array(range(T))
+  key_pos = jnp.array(range(S))
+  left_window, right_window = local_window_size
+  left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
+  right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
+  return jnp.logical_and(right_mask, left_mask)[None, None, :, :]
+
 def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
   q_mask = True
   kv_mask = True
@@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen):
   mask = q_indices < q_seqlen[:, None]
   return mask[:, :, None, None]
 
-def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
+def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+                 local_window_size):
   if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
     return logits
 
@@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
     mask = _get_causal_mask(T, S)
     combined_mask = jnp.logical_and(combined_mask, mask)
 
+  if local_window_size is not None:
+    mask = _get_window_mask(T, S, local_window_size)
+    combined_mask = jnp.logical_and(combined_mask, mask)
+
   if q_seqlen is not None or kv_seqlen is not None:
     mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
     combined_mask = jnp.logical_and(combined_mask, mask)
@@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen):
   return padded_logits
 
 def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
-                                scale, q_seqlen, kv_seqlen):
+                                scale, q_seqlen, kv_seqlen, local_window_size):
   logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
   logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                       preferred_element_type=logits_dtype)
@@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
   if bias is not None:
     logits = (logits + bias).astype(logits.dtype)
 
-  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen)
+  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
+                               local_window_size)
 
   # Softmax is always carried out in fp32.
   padded_logits = padded_logits.astype(jnp.float32)
@@ -857,7 +871,8 @@ def _dot_product_attention_xla(
     is_causal: bool,
     scale: float,
     q_seqlen: Array | None,
-    kv_seqlen: Array | None):
+    kv_seqlen: Array | None,
+    local_window_size: tuple[int, int] | None):
 
   B, T, N, H = query.shape
   _, S, K, _ = key.shape
@@ -875,11 +890,13 @@ def _reshape_to_grouped(t):
       return t
   bias = _reshape_to_grouped(bias)
   mask = _reshape_to_grouped(mask)
-  vmapped_fn = jax.vmap(_dot_product_attention_core,
-                        in_axes=(3, None, None, 2, 2, None, None, None, None),
-                        out_axes=3)
+  vmapped_fn = jax.vmap(
+      _dot_product_attention_core,
+      in_axes=(3, None, None, 2, 2, None, None, None, None, None),
+      out_axes=3,
+  )
   encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
-                       q_seqlen, kv_seqlen)
+                       q_seqlen, kv_seqlen, local_window_size)
   encoded = jnp.reshape(encoded, (B, T, N, H))
   return encoded
 
@@ -894,6 +911,7 @@ def dot_product_attention(
     is_causal: bool = False,
     query_seq_lengths: ArrayLike | None = None,
     key_value_seq_lengths: ArrayLike | None = None,
+    local_window_size: int | tuple[int, int] | None = None,
    implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
   r"""Scaled dot product attention function.
 
@@ -943,6 +961,12 @@
       :code:`(B)`
     key_value_seq_lengths: `int32` array of sequence lengths for key and value;
       shape :code:`(B)`
+    local_window_size: Window sizes to restrict self attention so that each
+      token attends only to its local window, specified as (left_window_size,
+      right_window_size) per token. E.g., if local_window_size == (3, 2) and
+      the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend to
+      [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as
+      a symmetric window (window_size, window_size).
     implementation: A string to control which implementation backend to use.
       Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
       to `None`, which will automatically select the best available backend.
@@ -969,6 +993,8 @@ def _ensure_4d(t):
     query_seq_lengths = jnp.asarray(query_seq_lengths)
   if key_value_seq_lengths is not None:
     key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
+  if isinstance(local_window_size, int):
+    local_window_size = (local_window_size, local_window_size)
 
   def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
                              dtype: DType | None, name: str) -> None:
@@ -1003,6 +1029,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
           query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
           scale=scale_val, q_seqlen=query_seq_lengths,
           kv_seqlen=key_value_seq_lengths,
+          local_window_size=local_window_size,
       )
     case 'cudnn':
       use_padding = (
@@ -1022,9 +1049,21 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
         mask_type = MaskType.CAUSAL
       elif use_padding:
         mask_type = MaskType.PADDING
+      # cuDNN only supports the left window, with an exclusive boundary; the
+      # right window must be zero or already covered by the causal mask.
+      sliding_window = None
+      if local_window_size is not None:
+        l_window, r_window = local_window_size
+        if r_window == 0 or mask_type == MaskType.CAUSAL:
+          sliding_window = l_window + 1
+        else:
+          raise ValueError(f"cuDNN doesn't support right window: {r_window} "
+                           "when causal mask is not used.")
+
       out = cudnn_dot_product_attention(
           query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
-          key_value_seq_lengths, scale=scale_val, mask_type=mask_type
+          key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
+          sliding_window_length=sliding_window,
       )
     case None:
       # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
       # the best backend.
       out = _dot_product_attention_xla(
           query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
           scale=scale_val, q_seqlen=query_seq_lengths,
           kv_seqlen=key_value_seq_lengths,
+          local_window_size=local_window_size,
       )
     case _:
       raise ValueError(f"Unsupported implementation option: {implementation}")
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 3722db42671c..be07de184e60 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -38,11 +38,11 @@
 
 config.parse_flags_with_absl()
 
-def _is_required_cudnn_version_satisfied():
+def _is_required_cudnn_version_satisfied(min_cudnn_version):
   return (
       jtu.is_cuda_compute_capability_at_least("8.0") and
       cuda_versions is not None and
-      cuda_versions.cudnn_get_version() >= 8904
+      cuda_versions.cudnn_get_version() >= min_cudnn_version
   )
 
 def _check_cudnn_backend(fn, *args, **kwargs):
@@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
       impl=['cudnn', 'xla'],
   )
   def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
+    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if impl == 'cudnn' and dtype == jnp.float32:
       raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@@ -102,13 +102,15 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
 
   @parameterized.product(
       mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
-                 ('custom', 'padding'), ('bias', 'causal')],
+                 ('custom', 'padding'), ('bias', 'causal'),
+                 ('causal', 'sliding_window')],
   )
   def testDotProductAttentionMask(self, mask_mode):
-    if not _is_required_cudnn_version_satisfied():
-      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
     if isinstance(mask_mode, str):
       mask_mode = (mask_mode,)
+    min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
+    if not _is_required_cudnn_version_satisfied(min_cudnn_version):
+      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
 
     dtype = jnp.bfloat16
     B, S, T, N, H = 2, 128, 128, 4, 32
@@ -119,6 +121,7 @@ def testDotProductAttentionMask(self, mask_mode):
     grad = random.normal(keys[3], (B, T, N, H), dtype)
     bias, mask = None, None
     q_seqlen, kv_seqlen = None, None
+    window_size = None
 
     is_causal = 'causal' in mask_mode
     if 'padding' in mask_mode:
@@ -130,6 +133,8 @@ def testDotProductAttentionMask(self, mask_mode):
       mask = custom_mask[None, None, :, :]
     if 'bias' in mask_mode:
       bias = random.normal(keys[4], (1, N, T, S), dtype)
+    if 'sliding_window' in mask_mode:
+      window_size = (3, 2) if is_causal else (3, 0)
 
     sdpa = nn.dot_product_attention
     sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
@@ -141,9 +146,11 @@ def testDotProductAttentionMask(self, mask_mode):
     # Convert the kwargs to positional args for jax.vjp.
     fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
         q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+        local_window_size=window_size,
     )
     fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
         q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
+        local_window_size=window_size,
     )
     out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
     out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
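
Usage sketch (reviewer note, not part of the patch): a minimal example of how the
new local_window_size argument is called through the public
jax.nn.dot_product_attention API that this patch extends. The shapes, the PRNG
seed, and the (2, 1) window below are illustrative assumptions only; with the
default implementation=None the call takes the XLA path patched above.

    import jax
    import jax.numpy as jnp

    # Toy inputs: batch=1, seq_len=8, num_heads=2, head_dim=16.
    q = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 2, 16), dtype=jnp.float32)

    # Each token attends to at most 2 keys to its left and 1 to its right.
    # A single int, e.g. local_window_size=2, means the symmetric window (2, 2).
    out = jax.nn.dot_product_attention(q, q, q, local_window_size=(2, 1))

    # Combined with is_causal=True the right side is masked anyway, which is the
    # configuration the cuDNN backend accepts (left window only, or right == 0).
    out_causal = jax.nn.dot_product_attention(
        q, q, q, is_causal=True, local_window_size=(2, 2))

    print(out.shape, out_causal.shape)  # (1, 8, 2, 16) (1, 8, 2, 16)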