Skip to content

Commit

Permalink
[flax] Add QK-normalization to MultiHeadDotProductAttention
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 554507242
  • Loading branch information
prazek authored and Flax Authors committed Aug 12, 2023
1 parent c8424b0 commit 0e45cd3
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,25 @@
"""Attention core modules for Flax."""

import functools
from typing import (Any, Callable, Optional, Tuple, Union)
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable, Optional, Tuple

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import default_kernel_init
from flax.linen.linear import DenseGeneral
from flax.linen.linear import DotGeneralT
from flax.linen.linear import PrecisionLike
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
from flax.linen.normalization import LayerNorm

import jax
from jax import lax
from jax import random
import jax.numpy as jnp


PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Expand Down Expand Up @@ -225,6 +227,7 @@ class MultiHeadDotProductAttention(Module):
key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,,
num_heads, value_channels]``
decode: whether to prepare and use an autoregressive cache.
normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
"""

num_heads: int
Expand All @@ -243,6 +246,7 @@ class MultiHeadDotProductAttention(Module):
use_bias: bool = True
attention_fn: Callable[..., Array] = dot_product_attention
decode: bool = False
normalize_qk: bool = False
qkv_dot_general: DotGeneralT = lax.dot_general
out_dot_general: DotGeneralT = lax.dot_general

Expand Down Expand Up @@ -303,6 +307,12 @@ def __call__(
dense(name='value')(inputs_kv),
)

if self.normalize_qk:
# Normalizing query and key projections stabilizes training with higher
# LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
query = LayerNorm(name='query_ln', use_bias=False)(query) # type: ignore[call-arg]
key = LayerNorm(name='key_ln', use_bias=False)(key) # type: ignore[call-arg]

# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.decode:
Expand Down

0 comments on commit 0e45cd3

Please sign in to comment.