Commit 9e6b819

Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful.

Fixes #4008

PiperOrigin-RevId: 646679331
IvyZX authored and Flax Authors committed Jun 26, 2024
1 parent 3b21870 commit 9e6b819
Showing 2 changed files with 22 additions and 29 deletions.
46 changes: 21 additions & 25 deletions flax/linen/attention.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import functools
+import inspect
 import warnings
 from typing import Any, Callable, Optional, Union, overload
 
@@ -574,33 +575,28 @@ def __call__(
       m_deterministic = True
 
     # apply attention
+    attn_args = (query, key, value)
+    # This kwargs list matches the default nn.dot_product_attention.
+    # For custom `attention_fn`s, invalid kwargs will be filtered.
+    attn_kwargs = dict(
+      mask=mask,
+      dropout_rng=dropout_rng,
+      dropout_rate=self.dropout_rate,
+      broadcast_dropout=self.broadcast_dropout,
+      deterministic=m_deterministic,
+      dtype=self.dtype,
+      precision=self.precision,
+      force_fp32_for_softmax=self.force_fp32_for_softmax,
+    )
+    attn_kwargs = {
+      k: v
+      for k, v in attn_kwargs.items()
+      if k in inspect.signature(self.attention_fn).parameters
+    }
     if sow_weights:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-        module=self,
-      )  # pytype: disable=wrong-keyword-args
+      x = self.attention_fn(*attn_args, **attn_kwargs, module=self)
     else:
-      x = self.attention_fn(
-        query,
-        key,
-        value,
-        mask=mask,
-        dropout_rng=dropout_rng,
-        dropout_rate=self.dropout_rate,
-        broadcast_dropout=self.broadcast_dropout,
-        deterministic=m_deterministic,
-        dtype=self.dtype,
-        precision=self.precision,
-      )
+      x = self.attention_fn(*attn_args, **attn_kwargs)
     # back to the original inputs dimensions
     out = DenseGeneral(
       features=features,
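
The inspect-based filter is what makes the forwarding safe: each keyword argument is passed to attention_fn only if that function declares a parameter of the same name, so custom attention_fns with narrower signatures than nn.dot_product_attention keep working. A minimal standalone sketch of the same pattern (filter_kwargs and custom_attention are hypothetical names for illustration, not Flax APIs):

import inspect

def filter_kwargs(fn, kwargs):
  # Keep only kwargs that appear by name in fn's signature,
  # mirroring the filtering applied to self.attention_fn above.
  params = inspect.signature(fn).parameters
  return {k: v for k, v in kwargs.items() if k in params}

def custom_attention(query, key, value, mask=None):
  # A user-supplied attention fn without force_fp32_for_softmax.
  return query

print(filter_kwargs(custom_attention,
                    {'mask': None, 'force_fp32_for_softmax': True}))
# -> {'mask': None}

One consequence of this design: a custom fn that absorbs extras via **kwargs exposes no matching parameter name, so those extras are filtered out rather than forwarded.
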
5 changes: 1 addition & 4 deletions tests/linen/linen_attention_test.py
@@ -14,7 +14,6 @@
 
 """Tests for flax.linen.attention."""
 
-import functools
 from absl.testing import absltest, parameterized
 from flax import errors, jax_utils
 from flax import linen as nn
@@ -565,9 +564,7 @@ def test_mixed_precision_multihead_attention(
       qkv_features=4,
       kernel_init=initializers.lecun_normal(),
       bias_init=initializers.uniform(),
-      attention_fn=functools.partial(
-        nn.dot_product_attention, force_fp32_for_softmax=force_fp32
-      ),
+      force_fp32_for_softmax=force_fp32,
       deterministic=False,
       dtype=jnp.bfloat16,
     )
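
With the module attribute actually forwarded, the functools.partial workaround the test previously used is no longer needed. A before/after usage sketch (num_heads and input shapes are illustrative assumptions):

import functools
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 4, 4), dtype=jnp.bfloat16)

# Before: the flag only took effect when partial-applied into attention_fn.
old_style = nn.MultiHeadDotProductAttention(
  num_heads=2,
  dtype=jnp.bfloat16,
  attention_fn=functools.partial(
    nn.dot_product_attention, force_fp32_for_softmax=True
  ),
)

# After: the module-level attribute is forwarded to attention_fn directly.
new_style = nn.MultiHeadDotProductAttention(
  num_heads=2,
  dtype=jnp.bfloat16,
  force_fp32_for_softmax=True,
)

variables = new_style.init(jax.random.key(0), x)
y = new_style.apply(variables, x)  # softmax in float32, output in bfloat16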
