Switch to beartype #325
Conversation
@jbloomAus can you rerun the CI? I think it failed due to an unrelated issue connecting to Hugging Face.
@jbloomAus thanks! this is ready for review. Feel free to change the
Thanks David, will review in the next couple of days :)
@dkamm, looks great! I just want to double check/compare equivalent error messages and also have a summary here of why we're switching. IIRC it's because we had testing issues and then install issues with typeguard? Feels like a long story but you have the most context. Once that's written, I'll share with Neel just to be sure he's cool with it.
We're switching because there have been issues getting the newer versions of typeguard to work on this codebase for various reasons. Typeguard 3 - have to add In contrast, the latest version of beartype works with the small number of changes described above. None of these changes preclude switching to typeguard 4 when that issue is resolved, but it's not a trivial issue to solve. I figured this solution is the best option right now.
Thanks @dkamm! @alan-cooney This seems like a good thing to me; are you able to verify/double check please?
Thanks for this PR - nice work!
Just to check - have you checked that providing an incorrect jaxtyping type causes beartype to throw an error (when running pytest)?
Otherwise just one question to confirm in the review - where I think I'm missing something.
One small thing to check - otherwise good to go.
Thanks again for this!
Alan
@dkamm, just pinging you to see if you've had a chance to action @alan-cooney's suggestions.
Just added it! Sorry I missed it earlier |
Sorry - it's out of sync actually with the latest changes to master. Are you able to update and then ping me and I'll get it merged right away?
     self,
-    q: Float[torch.Tensor, "batch pos head_index d_head"],
-    k: Float[torch.Tensor, "batch pos head_index d_head"],
+    q: Float[torch.Tensor, "batch q_pos head_index d_head"],
+    k: Float[torch.Tensor, "batch k_pos head_index d_head"],
     past_kv_pos_offset,
 ) -> Tuple[
-    Float[torch.Tensor, "batch pos head_index d_head"],
-    Float[torch.Tensor, "batch pos head_index d_head"],
+    Float[torch.Tensor, "batch q_pos head_index d_head"],
+    Float[torch.Tensor, "batch k_pos head_index d_head"],
 ]:
@alan-cooney @jbloomAus I had to change this typehint to get one of the tests to pass (test_hooked_transformer::test_dtypes). It failed because this method got passed q, k with different pos dimensions, and jaxtyping checks for that as configured. I'm not familiar enough with this method and its usage to say whether q, k should be allowed to have different pos dimensions.
Are you sure it's actually getting passed different sized queries and keys, or are the defined type signatures at this point just different?
I think this may be an issue with Beartype getting confused about the size of one of these tensors at a previous point (perhaps in the if past_kv_cache_entry is not None: section)?
I'm not very familiar with rotary embeddings either but as far as I can tell the q and k passed in here should be the same size (and they should also be the same size afterwards as they're immediately dot producted).
Are you sure it's actually getting passed different sized queries and keys, or are the defined type signatures at this point just different?
Yep, I confirmed that they were different sizes.
I think this may be an issue with Beartype getting confused about the size of one of these tensors at a previous point (perhaps in the if past_kv_cache_entry is not None: section)?
Hmm, I don't think this is the case. Beartype/Jaxtyping should only look at the arguments passed to the call in this case
Here's a minimal script to reproduce (you'll have to change the typehints back):

from jaxtyping import install_import_hook

with install_import_hook("transformer_lens", "beartype.beartype"):
    from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m")
_ = model.generate("Hello, World!")
Actually, reading the code more, I think my change is correct and we should allow different sizes in the pos dim. The reason is that when we're generating a completion, after we feed through the initial input, we're using only the latest token as the query vec and using the kv cache for the key vecs to compute attention scores. It's ok for q and k to have different sizes in the pos dim because they are getting dotted over the d_head dim. Finally, the past_kv_pos_offset is just so we get the right positional embedding for q.
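The shape argument above can be sketched with a toy attention-score computation. The dimension sizes are illustrative only, but they show that contracting over d_head never requires q_pos and k_pos to agree:

```python
# Illustrative numbers only: during incremental generation the query covers
# just the newest position while keys span the whole cache, so q_pos and
# k_pos legitimately differ. The einsum contracts over d_head, not pos.
import torch

batch, head_index, d_head = 1, 8, 64
q_pos, k_pos = 1, 12  # new token vs. cached sequence

q = torch.randn(batch, q_pos, head_index, d_head)
k = torch.randn(batch, k_pos, head_index, d_head)

# Attention scores: one query position attends over all cached key positions.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k)
print(scores.shape)  # (1, 8, 1, 12)
```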
Makes sense thanks!
@alan-cooney - I've updated the branch with latest changes from main. There's one new comment that requires review.
Thanks for the PR - merged now!
Description
Switching from typeguard to beartype because the latest typeguard is incompatible with jaxtyping.
Notes:
Union["FactoredMatrix", ...]
Screenshots
N/A