
Support unpacked FP4E2M1 #164

Closed
justinchuby opened this issue Aug 7, 2024 · 5 comments

@justinchuby

Creating this thread from #116 for a focused proposal on supporting FP4E2M1. Thanks!
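For reference, FP4E2M1 packs a sign bit, two exponent bits, and one mantissa bit into 4 bits. The sketch below decodes a 4-bit code, assuming the OCP MX-style interpretation (exponent bias 1, subnormals, no Inf/NaN codes); it is only an illustration, not the ml_dtypes implementation.

```cpp
#include <cstdint>
#include <cstdio>

// Decode a 4-bit E2M1 code (stored in the low 4 bits of a byte) to float,
// assuming 1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1.
float DecodeFp4E2M1(uint8_t code) {
  const int sign = (code >> 3) & 0x1;      // bit 3
  const int exponent = (code >> 1) & 0x3;  // bits 2..1
  const int mantissa = code & 0x1;         // bit 0

  float magnitude;
  if (exponent == 0) {
    // Subnormal: (mantissa / 2) * 2^(1 - bias) = mantissa * 0.5.
    magnitude = 0.5f * mantissa;
  } else {
    // Normal: (1 + mantissa / 2) * 2^(exponent - bias), with bias = 1.
    magnitude = (1.0f + 0.5f * mantissa) *
                static_cast<float>(1 << (exponent - 1));
  }
  return sign ? -magnitude : magnitude;
}

int main() {
  // Prints the eight non-negative values: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
  for (int code = 0; code < 8; ++code) {
    std::printf("%d -> %g\n", code, DecodeFp4E2M1(static_cast<uint8_t>(code)));
  }
  return 0;
}
```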

@balancap
Contributor

balancap commented Aug 9, 2024

From my experience implementing #166, I believe the present float8_base can be adapted without much difficulty to support FP4 (and FP6) dtypes.

I would follow the suggestion of @cloudhan: add a sizeof_bits type trait, specialized for the FP4/FP6 formats, that gives the true bit width of each dtype. Then, using sizeof_bits, TraitsBase can be extended to compute the proper bit masks for the exponent and mantissa parts (https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h#L891). The rest should almost work out of the box :) A rough sketch of the idea is below.
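A minimal sketch of the trait-based idea, assuming illustrative names (sizeof_bits, SmallFloatTraits, float4_e2m1); these do not necessarily match what ml_dtypes' float8.h or the eventual implementation actually uses.

```cpp
#include <cstdint>

// Default: a type's bit width is 8 * sizeof(T).
template <typename T>
struct sizeof_bits {
  static constexpr int value = 8 * sizeof(T);
};

struct float4_e2m1;  // hypothetical sub-byte dtype, stored in the low 4 bits of a byte

// Specialization giving the true bit width of the sub-byte format.
template <>
struct sizeof_bits<float4_e2m1> {
  static constexpr int value = 4;
};

// Traits deriving sign/exponent/mantissa masks from the bit layout,
// in the spirit of how TraitsBase does it for the float8 types.
template <typename T, int kExponentBits, int kMantissaBits>
struct SmallFloatTraits {
  static constexpr int kBits = sizeof_bits<T>::value;
  static constexpr uint8_t kMantissaMask = (1u << kMantissaBits) - 1;
  static constexpr uint8_t kExponentMask =
      ((1u << kExponentBits) - 1) << kMantissaBits;
  static constexpr uint8_t kSignMask = 1u << (kExponentBits + kMantissaBits);
  static_assert(kBits == 1 + kExponentBits + kMantissaBits,
                "sign + exponent + mantissa must fill the format");
};

using Float4E2M1Traits = SmallFloatTraits<float4_e2m1, /*exp=*/2, /*man=*/1>;
static_assert(Float4E2M1Traits::kSignMask == 0b1000, "");
static_assert(Float4E2M1Traits::kExponentMask == 0b0110, "");
static_assert(Float4E2M1Traits::kMantissaMask == 0b0001, "");
```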

@hawkinsp
Collaborator

#181 did this!

@justinchuby
Author

@hawkinsp thanks! Do you know when it will be released?

@hawkinsp
Collaborator

Today, hopefully.

@justinchuby
Author

That’s amazing, thank you!
