Hi community!

I was wondering if there is a more efficient approach to dealing with the mask in dot_product_attention.

Recently, I have been trying to implement a cross-attention block with a very sparse mask over queries or keys (e.g., only 100 locations out of a sequence of length 10000 are set to True while the rest are False), based on Flax. Although the mask is supported in dot_product_attention, I found that it is handled in a very direct way: the attention weights at False positions are simply set to big_neg, as in:
if mask is not None:
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)
So it simply screens out the positions whose mask is False. While this keeps the function easy to jit, to my (admittedly shallow) understanding it does not reduce the memory or, especially, the computation cost of the attention: the full [q_len, kv_len] weight matrix is still computed and softmaxed, and the masked entries only become (numerically) zero after the softmax. This is rather limiting when dealing with long sequences with a sparse mask.

I cannot simply reduce the maximum sequence length as a temporary workaround, and I believe this problem is fairly common in meta-learning. So may I ask whether there are any insights on how to handle such a use case, or is this a generic limitation of using masks in attention blocks in JAX?
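For reference, the kind of workaround I have in mind is sketched below. This is not an existing Flax or JAX API: sparse_key_attention and MAX_ACTIVE are names I made up for illustration, and the sketch assumes a static upper bound on the number of True key positions is known in advance. The idea is to gather the active keys/values into a fixed-size buffer first, so the attention weights have shape [q_len, MAX_ACTIVE] instead of [q_len, kv_len] while all shapes stay static for jit:

import jax
import jax.numpy as jnp
import flax.linen as nn

MAX_ACTIVE = 128  # hypothetical static upper bound on the number of True key positions

def sparse_key_attention(query, key, value, key_mask):
    # query:     [batch, q_len, num_heads, head_dim]
    # key/value: [batch, kv_len, num_heads, head_dim]
    # key_mask:  [batch, kv_len] boolean, mostly False
    def gather_one(k, v, m):
        # Indices of the True positions, padded up to the static size MAX_ACTIVE
        # so that shapes stay fixed under jit; padded slots point at index 0.
        idx = jnp.nonzero(m, size=MAX_ACTIVE, fill_value=0)[0]
        valid = jnp.arange(MAX_ACTIVE) < jnp.sum(m)  # marks the real (non-padding) slots
        return k[idx], v[idx], valid

    k_small, v_small, valid = jax.vmap(gather_one)(key, value, key_mask)
    # Mask broadcastable to [batch, num_heads, q_len, MAX_ACTIVE]; the padding
    # slots are screened out with big_neg exactly as in the snippet above,
    # but now over a much shorter key dimension.
    small_mask = valid[:, None, None, :]
    return nn.dot_product_attention(query, k_small, v_small, mask=small_mask)

The static MAX_ACTIVE bound is what keeps this jittable, but it only covers a sparse key mask; a sparse query mask would seem to need a similar gather/scatter on the query side, and I am not sure this is the idiomatic way to handle it in the first place, hence the question.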