From 6bce7204622ce90a295a9280efff6c5ec9d353ed Mon Sep 17 00:00:00 2001 From: Piotr Padlewski Date: Mon, 7 Aug 2023 09:54:49 -0700 Subject: [PATCH] [flax] Add QK-normalization to MultiHeadDotProductAttention PiperOrigin-RevId: 554507242 --- flax/linen/attention.py | 100 ++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 51 deletions(-) diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 2c6400be07..5a86157379 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,10 +15,10 @@ """Attention core modules for Flax.""" import functools -from typing import (Any, Callable, Optional, Tuple, Union) -from flax.linen.dtypes import promote_dtype +from typing import Any, Callable, Optional, Tuple from flax.linen import initializers +from flax.linen.dtypes import promote_dtype from flax.linen.linear import default_kernel_init from flax.linen.linear import DenseGeneral from flax.linen.linear import DotGeneralT @@ -26,12 +26,13 @@ from flax.linen.module import compact from flax.linen.module import merge_param from flax.linen.module import Module - +from flax.linen.normalization import LayerNorm import jax from jax import lax from jax import random import jax.numpy as jnp + PRNGKey = Any Shape = Tuple[int, ...] Dtype = Any @@ -57,19 +58,17 @@ def dot_product_attention_weights( you can directly call this function and call einsum yourself. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -145,21 +144,19 @@ def dot_product_attention( Note: query, key, value needn't have any batch dimensions. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch..., kv_length, + num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -225,6 +222,7 @@ class MultiHeadDotProductAttention(Module): key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` decode: whether to prepare and use an autoregressive cache. + normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442). """ num_heads: int @@ -243,6 +241,7 @@ class MultiHeadDotProductAttention(Module): use_bias: bool = True attention_fn: Callable[..., Array] = dot_product_attention decode: bool = False + normalize_qk: bool = False qkv_dot_general: DotGeneralT = lax.dot_general out_dot_general: DotGeneralT = lax.dot_general @@ -260,17 +259,13 @@ def __call__( applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. @@ -303,6 +298,12 @@ def __call__( dense(name='value')(inputs_kv), ) + if self.normalize_qk: + # Normalizing query and key projections stabilizes training with higher + # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. + query = LayerNorm(name='query_ln', use_bias=False)(query) # type: ignore[call-arg] + key = LayerNorm(name='key_ln', use_bias=False)(key) # type: ignore[call-arg] + # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.decode: @@ -413,15 +414,12 @@ def __call__( # type: ignore applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. @@ -451,8 +449,8 @@ def make_attention_mask( query_input: a batched, flat input of query_length size key_input: a batched, flat input of key_length size pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton - axes for, none by default + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default dtype: mask return dtype Returns: @@ -477,8 +475,8 @@ def make_causal_mask( Args: x: input array of shape `[batch..., len]` - extra_batch_dims: number of batch dims to add singleton axes for, - none by default + extra_batch_dims: number of batch dims to add singleton axes for, none by + default dtype: mask return dtype Returns: