Commit
Merge pull request #4291 from google:init-cache-docs
PiperOrigin-RevId: 689084386
Flax Authors committed Oct 23, 2024
2 parents 3d9c7e0 + 363f3df commit 3f3c03b
Showing 1 changed file with 6 additions and 4 deletions.
flax/nnx/nn/attention.py

@@ -569,23 +569,25 @@ def __call__(
   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     """Initializes cache for fast autoregressive decoding. When
     ``decode=True``, this method must be called first before performing
-    forward inference.
+    forward inference. When in decode mode, only one token must be passed
+    at a time.
 
     Example usage::
 
       >>> from flax import nnx
       >>> import jax.numpy as jnp
       ...
-      >>> rngs = nnx.Rngs(42)
+      >>> batch_size = 5
+      >>> embed_dim = 3
+      >>> x = jnp.ones((batch_size, 1, embed_dim))  # single token
       ...
-      >>> x = jnp.ones((1, 3))
       >>> model_nnx = nnx.MultiHeadAttention(
       ...   num_heads=2,
       ...   in_features=3,
       ...   qkv_features=6,
       ...   out_features=6,
       ...   decode=True,
-      ...   rngs=rngs,
+      ...   rngs=nnx.Rngs(42),
       ... )
       ...
       >>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
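For context, here is a minimal runnable sketch of the decode workflow the updated docstring describes, using the ``nnx.MultiHeadAttention`` constructor arguments shown in the diff; the four-step loop and the ``model`` variable name are illustrative additions, not part of the commit:

from flax import nnx
import jax.numpy as jnp

batch_size = 5
embed_dim = 3

model = nnx.MultiHeadAttention(
    num_heads=2,
    in_features=3,
    qkv_features=6,
    out_features=6,
    decode=True,        # autoregressive decoding mode
    rngs=nnx.Rngs(42),
)

# With decode=True, the cache must be initialized before the first forward
# call, and each call may then consume only a single token per step.
x = jnp.ones((batch_size, 1, embed_dim))  # (batch, one token, features)
model.init_cache(x.shape)

for _ in range(4):  # illustrative: decode four steps, one token at a time
    out = model(x)  # out.shape == (batch_size, 1, 6) since out_features=6

Calling ``model(x)`` before ``model.init_cache(x.shape)`` raises an error, which is exactly the pitfall this docstring change calls out.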
