Potentially wrong type inference #1188

Open
yongchanghao opened this issue Mar 6, 2023 · 0 comments

@yongchanghao
Contributor

The docs say that the dtype of `mu` is inferred from `grads` and `updates` when `mu_dtype=None`.

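However, this line actually casts `jnp.bfloat16` and `jnp.float16` inputs to `jnp.float32` when `mu_dtype=None`. Calling `.astype(None)` does not preserve the input dtype. It converts to the default floating-point dtype (`float32`), as the example below shows.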

Example (on GPU):

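```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones(3)  # hypothetical input; the original report does not show how x was defined
>>> jax.__version__
'0.4.4'
>>> x.astype(jnp.float16).dtype
dtype('float16')
>>> x.astype(jnp.float16).astype(None).dtype
dtype('float32')
```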
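One way to get the documented behavior is to skip the cast entirely when `mu_dtype` is `None`, so each leaf keeps the dtype it was created with. The sketch below assumes `mu` is an Adam-style first-moment accumulator initialized as zeros shaped like the parameters; `cast_tree_preserving` and `init_first_moment` are hypothetical helpers for illustration, not the project's actual code.

```python
import jax
import jax.numpy as jnp


def cast_tree_preserving(tree, dtype=None):
    """Cast every leaf of `tree` to `dtype`; if `dtype` is None, return the tree unchanged.

    Hypothetical helper: it avoids `.astype(None)`, which (per the report above)
    turns float16/bfloat16 leaves into float32.
    """
    if dtype is None:
        return tree
    return jax.tree_util.tree_map(lambda leaf: leaf.astype(dtype), tree)


def init_first_moment(params, mu_dtype=None):
    # Hypothetical Adam-style init: zeros with each parameter's shape and dtype,
    # then an explicit cast only if mu_dtype was requested.
    mu = jax.tree_util.tree_map(jnp.zeros_like, params)
    return cast_tree_preserving(mu, mu_dtype)


params = {
    "w": jnp.ones((2, 2), dtype=jnp.bfloat16),
    "b": jnp.ones(2, dtype=jnp.float16),
}
mu = init_first_moment(params)
print(mu["w"].dtype, mu["b"].dtype)  # bfloat16 float16
```

Passing an explicit dtype, e.g. `init_first_moment(params, jnp.float32)`, still casts every leaf to `float32`, so callers who want higher-precision accumulators can opt in.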
@yongchanghao changed the title from "Potential wrong type inference" to "Potentially wrong type inference" on Mar 6, 2023.