How to use init_cache() in NNX MHA correctly? #4290

Closed Answered by cgarciae
windmaple asked this question in General

Hey @windmaple, in decode mode you must pass a single token per call. Initialize the cache with the shape of the full sequence, then feed the tokens one at a time, e.g.:

import jax.numpy as jnp
from flax import nnx

batch_size = 2
seqlen = 40
emb_size = 256

x = jnp.ones((batch_size, seqlen, emb_size))

mha = nnx.MultiHeadAttention(
  in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)

# Size the cache for the full sequence: (batch_size, seqlen, emb_size).
mha.init_cache(x.shape, dtype=x.dtype)
for i in range(seqlen):  # feed one token per call
  y = mha(x[:, i : i + 1])  # the slice keeps the length-1 sequence axis
print('success')
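
For intuition on why decode mode wants one token per call, here is a minimal pure-JAX sketch of a key/value cache. It is not Flax's implementation: `make_cache`, `decode_step`, and `full_causal_attention` are hypothetical helpers written for this illustration, and the sketch assumes a single head with no learned projections (each token serves as its own query, key, and value). The idea it shows is that the cache is pre-allocated for the full sequence length, each call writes one new key/value slot, and the new token attends only to the slots filled so far.

```python
import jax
import jax.numpy as jnp


def make_cache(batch_size, max_len, dim, dtype=jnp.float32):
  """Pre-allocate empty key/value buffers plus a write position."""
  return {
      'key': jnp.zeros((batch_size, max_len, dim), dtype),
      'value': jnp.zeros((batch_size, max_len, dim), dtype),
      'index': 0,
  }


def decode_step(cache, q, k, v):
  """Attend one new token (shape [batch, 1, dim]) to everything cached so far."""
  i = cache['index']
  # Write the new key/value into slot i of the pre-allocated buffers.
  keys = jax.lax.dynamic_update_slice(cache['key'], k, (0, i, 0))
  values = jax.lax.dynamic_update_slice(cache['value'], v, (0, i, 0))
  # Slots after position i are still empty, so mask them out.
  valid = jnp.arange(keys.shape[1]) <= i
  scores = jnp.einsum('bqd,bkd->bqk', q, keys) / jnp.sqrt(q.shape[-1])
  scores = jnp.where(valid[None, None, :], scores, -jnp.inf)
  out = jnp.einsum('bqk,bkd->bqd', jax.nn.softmax(scores, axis=-1), values)
  return {'key': keys, 'value': values, 'index': i + 1}, out


def full_causal_attention(q, k, v):
  """Reference: process the whole sequence at once with a causal mask."""
  seqlen = q.shape[1]
  scores = jnp.einsum('bqd,bkd->bqk', q, k) / jnp.sqrt(q.shape[-1])
  causal = jnp.tril(jnp.ones((seqlen, seqlen), dtype=bool))
  scores = jnp.where(causal[None], scores, -jnp.inf)
  return jnp.einsum('bqk,bkd->bqd', jax.nn.softmax(scores, axis=-1), v)


batch_size, seqlen, dim = 2, 8, 16
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, seqlen, dim))

cache = make_cache(batch_size, seqlen, dim, x.dtype)
outputs = []
for i in range(seqlen):
  tok = x[:, i : i + 1]  # one token per step, sequence axis kept
  cache, out = decode_step(cache, tok, tok, tok)
  outputs.append(out)
y = jnp.concatenate(outputs, axis=1)  # shape [batch, seqlen, dim]
```

Under these assumptions, the concatenated per-token outputs should match `full_causal_attention(x, x, x)` up to floating-point tolerance, which is the sense in which token-by-token decoding with a cache reproduces causal attention over the whole sequence.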

Answer selected by windmaple