So I'm trying to enable the KV cache. Here is my code:

```python
batch_size = 2
x = jnp.ones((batch_size, seqlen, emb_size))
mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache((x.shape[0], x.shape[1], x.shape[-1]), dtype=x.dtype)
```

I got this error:

It seems the 'seqlen' dimension is not being used. Looking at the init_cache() code, it looks like it should be. The docstring only uses a 2-dimensional array as an example, so I can't quite figure this one out.
Hey @windmaple, when in `decode` mode only a single token must be passed at a time, e.g.:

```python
import jax.numpy as jnp
from flax import nnx

batch_size = 2
seqlen = 40
emb_size = 256

x = jnp.ones((batch_size, seqlen, emb_size))
mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=2, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache((x.shape[0], x.shape[1], x.shape[-1]), dtype=x.dtype)

for i in range(seqlen):  # feed the tokens one at a time
    y = mha(x[:, i : i + 1])
print('success')
```
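
For context (this note is not part of the original reply): the `seqlen` you pass to `init_cache()` sets the maximum number of decode steps. The key/value cache buffers are allocated at that full length up front, and an internal index advances by one on each call, which is why the dimension matters even though every call only sees a single token. A minimal sketch of how to verify this, assuming the `cached_key` attribute name used in the Flax NNX source at the time of writing:

```python
import jax.numpy as jnp
from flax import nnx

batch_size, seqlen, emb_size, num_heads = 2, 40, 256, 2

mha = nnx.MultiHeadAttention(
    in_features=emb_size, num_heads=num_heads, decode=True, rngs=nnx.Rngs(0)
)
mha.init_cache((batch_size, seqlen, emb_size), dtype=jnp.float32)

# The cache is allocated for the full sequence up front:
# (batch, seqlen, num_heads, head_dim) == (2, 40, 2, 128)
print(mha.cached_key.value.shape)
```

If the cache shape checks out but calls still fail, double-check that each call in decode mode receives a query of length 1, since passing more tokens per step is rejected.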