[flax] Add QK-normalization to MultiHeadDotProductAttention
PiperOrigin-RevId: 554507242
prazek authored and Flax Authors committed Aug 12, 2023
1 parent c8424b0 commit 510c77f
Showing 1 changed file with 49 additions and 51 deletions.
flax/linen/attention.py
@@ -15,23 +15,24 @@
"""Attention core modules for Flax."""

import functools
from typing import (Any, Callable, Optional, Tuple, Union)
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable, Optional, Tuple

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.linear import DotGeneralT
from flax.linen.linear import PrecisionLike
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module

from flax.linen.normalization import LayerNorm
import jax
from jax import lax
from jax import random
import jax.numpy as jnp


PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
@@ -57,19 +58,17 @@ def dot_product_attention_weights(
you can directly call this function and call einsum yourself.
Args:
query: queries for calculating attention with shape of
`[batch..., q_length, num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of
`[batch..., kv_length, num_heads, qk_depth_per_head]`.
query: queries for calculating attention with shape of `[batch..., q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch..., kv_length,
num_heads, qk_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`.
This can be used for incorporating causal masks, padding masks,
proximity bias, etc.
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`.
This can be used for incorporating causal masks.
Attention weights are masked out if their corresponding mask value
is `False`.
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is `False`.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
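
For reference (not part of this diff), a minimal sketch of calling `dot_product_attention_weights` directly with the shapes documented above; the sizes are arbitrary and `flax.linen` is assumed to be importable as `nn`:

import jax
import jax.numpy as jnp
import flax.linen as nn

batch, q_len, kv_len, num_heads, head_dim = 2, 4, 6, 8, 16
kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (batch, q_len, num_heads, head_dim))
key = jax.random.normal(kk, (batch, kv_len, num_heads, head_dim))

# Attention weights only; the einsum with the values is left to the caller.
weights = nn.dot_product_attention_weights(query, key)
print(weights.shape)  # (2, 8, 4, 6) -> [batch, num_heads, q_length, kv_length]
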
@@ -145,21 +144,19 @@ def dot_product_attention(
Note: query, key, value needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of
`[batch..., q_length, num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of
`[batch..., kv_length, num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of
`[batch..., kv_length, num_heads, v_depth_per_head]`.
query: queries for calculating attention with shape of `[batch..., q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch..., kv_length,
num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch..., kv_length,
num_heads, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`.
This can be used for incorporating causal masks, padding masks,
proximity bias, etc.
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
mask: mask for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`.
This can be used for incorporating causal masks.
Attention weights are masked out if their corresponding mask value
is `False`.
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks. Attention weights are masked out if their
corresponding mask value is `False`.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
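
A similar sketch (again not part of the commit) for the full `dot_product_attention` op, with arbitrary sizes:

import jax
import jax.numpy as jnp
import flax.linen as nn

batch, q_len, kv_len, num_heads, head_dim = 2, 4, 6, 8, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (batch, q_len, num_heads, head_dim))
key = jax.random.normal(kk, (batch, kv_len, num_heads, head_dim))
value = jax.random.normal(kv, (batch, kv_len, num_heads, head_dim))

# Weighted sum over values; the output keeps the query length and value depth.
out = nn.dot_product_attention(query, key, value)
print(out.shape)  # (2, 4, 8, 16) -> [batch, q_length, num_heads, v_depth_per_head]
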
@@ -225,6 +222,7 @@ class MultiHeadDotProductAttention(Module):
key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,
num_heads, value_channels]`
decode: whether to prepare and use an autoregressive cache.
normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
"""

num_heads: int
@@ -243,6 +241,7 @@ class MultiHeadDotProductAttention(Module):
use_bias: bool = True
attention_fn: Callable[..., Array] = dot_product_attention
decode: bool = False
normalize_qk: bool = False
qkv_dot_general: DotGeneralT = lax.dot_general
out_dot_general: DotGeneralT = lax.dot_general
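
To illustrate the new attribute (example code, not part of the diff): with a Flax build that includes this change, `normalize_qk` is just another constructor flag; the other settings shown here are arbitrary.

import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    qkv_features=512,    # total projected dimension, split across heads
    dropout_rate=0.0,
    normalize_qk=True,   # new flag from this commit: LayerNorm on Q/K projections
)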

@@ -260,17 +259,13 @@ def __call__(
applies dot-product attention and project the results to an output vector.
Args:
inputs_q: input queries of shape
`[batch_sizes..., length, features]`.
inputs_kv: key/values of shape
`[batch_sizes..., length, features]`.
mask: attention mask of shape
`[batch_sizes..., num_heads, query_length, key/value_length]`.
Attention weights are masked out if their corresponding mask value
is `False`.
deterministic: if false, the attention weight is masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
deterministic: if false, the attention weight is masked randomly using
dropout, whereas if true, the attention weights are deterministic.
Returns:
output of shape `[batch_sizes..., length, features]`.
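
A minimal sketch of the call signature documented above (not part of the commit), using a broadcastable mask and arbitrary shapes:

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32)
q = jnp.ones((2, 5, 32))                    # [batch, q_length, features]
kv = jnp.ones((2, 7, 32))                   # [batch, kv_length, features]
mask = jnp.ones((2, 1, 5, 7), dtype=bool)   # broadcasts over the head axis

variables = attn.init(jax.random.PRNGKey(0), q, kv, mask=mask)
out = attn.apply(variables, q, kv, mask=mask)
print(out.shape)  # (2, 5, 32)
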
@@ -303,6 +298,12 @@ def __call__(
dense(name='value')(inputs_kv),
)

if self.normalize_qk:
# Normalizing query and key projections stabilizes training with higher
# LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
query = LayerNorm(name='query_ln', use_bias=False)(query) # type: ignore[call-arg]
key = LayerNorm(name='key_ln', use_bias=False)(key) # type: ignore[call-arg]

# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.decode:
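
One way to see the effect of the new branch above (illustrative, not part of the diff): enabling `normalize_qk` should add two scale-only LayerNorms, named `query_ln` and `key_ln` as in the code, to the parameter tree.

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32, normalize_qk=True)
x = jnp.ones((1, 5, 32))
variables = attn.init(jax.random.PRNGKey(0), x, x)
print(sorted(variables['params']))
# ['key', 'key_ln', 'out', 'query', 'query_ln', 'value']
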
@@ -413,15 +414,12 @@ def __call__( # type: ignore
applies dot-product attention and project the results to an output vector.
Args:
inputs_q: input queries of shape
`[batch_sizes..., length, features]`.
mask: attention mask of shape
`[batch_sizes..., num_heads, query_length, key/value_length]`.
Attention weights are masked out if their corresponding mask value
is `False`.
deterministic: if false, the attention weight is masked randomly
using dropout, whereas if true, the attention weights
are deterministic.
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
deterministic: if false, the attention weight is masked randomly using
dropout, whereas if true, the attention weights are deterministic.
Returns:
output of shape `[batch_sizes..., length, features]`.
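
For completeness (not part of the commit), `SelfAttention` takes a single input and inherits the new flag from its parent class; a minimal sketch with arbitrary sizes:

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.SelfAttention(num_heads=4, qkv_features=32, normalize_qk=True)
x = jnp.ones((2, 9, 32))  # [batch, length, features]
variables = attn.init(jax.random.PRNGKey(0), x)
out = attn.apply(variables, x)
print(out.shape)  # (2, 9, 32)
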
@@ -451,8 +449,8 @@ def make_attention_mask(
query_input: a batched, flat input of query_length size
key_input: a batched, flat input of key_length size
pairwise_fn: broadcasting elementwise comparison function
extra_batch_dims: number of extra batch dims to add singleton
axes for, none by default
extra_batch_dims: number of extra batch dims to add singleton axes for, none
by default
dtype: mask return dtype
Returns:
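
Example (not part of the diff): building a padding mask from per-token validity flags with the default `pairwise_fn=jnp.multiply`; the token patterns below are made up.

import jax.numpy as jnp
import flax.linen as nn

q_valid = jnp.array([[1, 1, 1, 0]])        # 1 = real token, 0 = padding
kv_valid = jnp.array([[1, 1, 0, 0, 0]])
mask = nn.make_attention_mask(q_valid, kv_valid)
print(mask.shape)  # (1, 1, 4, 5) -> [batch, 1 (head axis), q_length, kv_length]
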
@@ -477,8 +475,8 @@ def make_causal_mask(
Args:
x: input array of shape `[batch..., len]`
extra_batch_dims: number of batch dims to add singleton axes for,
none by default
extra_batch_dims: number of batch dims to add singleton axes for, none by
default
dtype: mask return dtype
Returns:
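
Example (not part of the diff): the causal mask only needs the `[batch..., len]` shape of `x`, not its values.

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 6))            # [batch, len]
causal = nn.make_causal_mask(x)
print(causal.shape)             # (2, 1, 6, 6)
# causal[0, 0] is lower-triangular: position i may attend to positions <= i.
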
