Switch to beartype #325

Merged: 12 commits merged into TransformerLensOrg:main on Aug 3, 2023

Conversation

@dkamm (Contributor) commented Jun 16, 2023

Description

Switching from typeguard to beartype due to an incompatibility with jaxtyping on the latest typeguard version.

Notes:

  • typeguard is still in the lockfile because jaxtyping requires it
  • I had to use forward references in these spots to get beartype to work (see the sketch after this list):
    • devices.py - this is a circular import solely due to type hints
    • utils.py - this is a true circular import that we might want to remove
    • FactoredMatrix.py - needed for methods returning Union["FactoredMatrix", ...]
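
For illustration, here is a minimal sketch of that forward-reference pattern (not the PR's actual code; the scale() method name is hypothetical). The string literal defers evaluation of the annotation, so beartype resolves the class at call time, after it is fully defined:

from typing import Union

import torch


class FactoredMatrix:
    # Illustrative stand-in for transformer_lens.FactoredMatrix; scale() is a
    # made-up method used only to show the annotation pattern.
    def scale(self, c: float) -> Union["FactoredMatrix", torch.Tensor]:
        # A bare FactoredMatrix in this annotation would typically raise
        # NameError, since the class name is not bound until the class
        # statement finishes executing; the string literal avoids that.
        ...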

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Screenshots

N/A

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@dkamm (Contributor, Author) commented Jun 16, 2023

@jbloomAus can you rerun the CI? I think it failed due to an unrelated issue with connecting to Hugging Face.

@dkamm (Contributor, Author) commented Jun 17, 2023

@jbloomAus thanks! This is ready for review. Feel free to change the composition_scores type hint, as I'm not sure how it works.

@jbloomAus (Collaborator) commented:

Thanks David, will review in the next couple of days :)

@jbloomAus (Collaborator) commented:

@dkamm, looks great! I just want to double-check/compare the equivalent error messages, and also have a summary here of why we're switching. IIRC it's because we had testing issues and then install issues with typeguard? It feels like a long story, but you have the most context. Once that's written, I'll share it with Neel just to be sure he's cool with it.

@jbloomAus added the seen_by_maintainers label ("Confirms that a maintainer is aware of this card.") on Jun 20, 2023
@dkamm (Contributor, Author) commented Jun 28, 2023

@jbloomAus

We're switching because there have been issues getting the newer versions of typeguard to work on this codebase, for various reasons:

  • Typeguard 3: you have to add @typeguard_ignore to class properties, and it doesn't work with the jaxtyping pytest plugin.
  • Typeguard 4: incompatible with jaxtyping because of how it treats annotations as forward references in its AST transformation.

In contrast, the latest version of beartype works with the small number of changes described above.

None of these changes preclude switching to typeguard 4 once that issue is resolved, but it's not a trivial issue to solve, so I figured this is the best option right now.
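
As a rough sketch of what the beartype wiring can look like: jaxtyping's import hook can apply beartype to annotated callables in a package. The install_import_hook call below is jaxtyping's documented API, but the conftest.py placement is an assumption rather than necessarily how this repo configures it:

# conftest.py (illustrative): ask jaxtyping's import hook to wrap annotated
# callables in transformer_lens with beartype.beartype, so shape and dtype
# annotations are checked whenever the test suite calls them.
from jaxtyping import install_import_hook

hook = install_import_hook("transformer_lens", "beartype.beartype")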

@jbloomAus (Collaborator) commented:

Thanks @dkamm!

@alan-cooney This seems like a good thing to me; are you able to verify/double-check, please?

@alan-cooney (Collaborator) left a comment:

Thanks for this PR - nice work!

Just to check: have you verified that providing an incorrect jaxtyping type causes beartype to throw an error (when running pytest)?

Otherwise, just one question to confirm in the review, where I think I'm missing something.

transformer_lens/FactoredMatrix.py (review thread resolved)
@alan-cooney (Collaborator) left a comment:

One small thing to check - otherwise good to go.

Thanks again for this!
Alan

transformer_lens/FactoredMatrix.py (review thread resolved)
transformer_lens/utils.py (review thread resolved)
@jbloomAus (Collaborator) commented:

@dkamm, just pinging you to see if you've had a chance to action @alan-cooney's suggestions.

@dkamm (Contributor, Author) commented Jul 26, 2023

Just added it! Sorry I missed it earlier.

alan-cooney (comment marked as outdated)

@alan-cooney (Collaborator) left a comment:

Sorry, it's actually out of sync with the latest changes to master. Are you able to update it and then ping me? I'll get it merged right away.

Comment on lines 639 to 646 (original lines marked with -, replacements with +):

  self,
- q: Float[torch.Tensor, "batch pos head_index d_head"],
- k: Float[torch.Tensor, "batch pos head_index d_head"],
+ q: Float[torch.Tensor, "batch q_pos head_index d_head"],
+ k: Float[torch.Tensor, "batch k_pos head_index d_head"],
  past_kv_pos_offset,
  ) -> Tuple[
- Float[torch.Tensor, "batch pos head_index d_head"],
- Float[torch.Tensor, "batch pos head_index d_head"],
+ Float[torch.Tensor, "batch q_pos head_index d_head"],
+ Float[torch.Tensor, "batch k_pos head_index d_head"],
  ]:
@dkamm (Contributor, Author) commented:

@alan-cooney @jbloomAus I had to change this type hint to get one of the tests to pass (test_hooked_transformer::test_dtypes). It failed because this method got passed q and k with different pos dimensions, and jaxtyping checks for that as configured. I'm not familiar enough with this method and its usage to say whether q and k should be allowed to have different pos dimensions.

@alan-cooney (Collaborator) commented:

Are you sure it's actually getting passed different-sized queries and keys, or are the defined type signatures at this point just different?

I think this may be an issue with Beartype getting confused about the size of one of these tensors at a previous point (perhaps in the if past_kv_cache_entry is not None: section)?

I'm not very familiar with rotary embeddings either, but as far as I can tell the q and k passed in here should be the same size (and they should also be the same size afterwards, as they're immediately dot-producted).

@dkamm (Contributor, Author) commented:

> Are you sure it's actually getting passed different-sized queries and keys, or are the defined type signatures at this point just different?

Yep, I confirmed that they were different sizes.

> I think this may be an issue with Beartype getting confused about the size of one of these tensors at a previous point (perhaps in the if past_kv_cache_entry is not None: section)?

Hmm, I don't think this is the case. Beartype/jaxtyping should only look at the arguments passed to the call in this case.

Here's a minimal script to reproduce (you'll have to change the type hints back):

from jaxtyping import install_import_hook
with install_import_hook("transformer_lens", "beartype.beartype"):
    from transformer_lens import HookedTransformer
    model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m")
    _ = model.generate("Hello, World!")

Actually, reading the code more, I think my change is correct and we should allow different sizes in the pos dim. The reason is that when we're generating a completion, after we feed through the initial input, we use only the latest token as the query vector and use the kv cache for the key vectors to compute attention scores. It's fine for q and k to have different sizes in the pos dim because they get dotted over the d_head dim. Finally, past_kv_pos_offset is just there so we get the right positional embedding for q.
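
A tiny sketch (with made-up sizes, not the library's code) of why the mismatch is fine: the score contraction is over d_head, so q_pos and k_pos are free to differ, as they do when a single new query token attends over the cached prefix:

import torch

batch, head_index, d_head = 1, 8, 64
q = torch.randn(batch, 1, head_index, d_head)   # only the newest token's query
k = torch.randn(batch, 12, head_index, d_head)  # keys for the full cached prefix

# One score per (query position, key position) pair; the dot product is over d_head.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / d_head**0.5
print(scores.shape)  # torch.Size([1, 8, 1, 12])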

@alan-cooney (Collaborator) commented:

Makes sense, thanks!

@dkamm (Contributor, Author) commented Aug 1, 2023

@alan-cooney I've updated the branch with the latest changes from main. There's one new comment that requires review.

@alan-cooney merged commit 10d2f8a into TransformerLensOrg:main on Aug 3, 2023 (4 checks passed)
@alan-cooney (Collaborator) commented:

Thanks for the PR - merged now!
