Handling of Nat at the Jax-Dex boundary #1046

Open
axch opened this issue Sep 2, 2022 · 3 comments

@axch
Collaborator

axch commented Sep 2, 2022

Suppose we wish to export a Dex function that takes a Nat argument to Jax:

dex_iota = primitive(dex.eval(r"\(size:Nat). for i:(Fin size). ordinal i"))

If we just call it, it works ok:

dex_iota(5)
> [0, 1, 2, 3, 4]

But if we jit it first, it shows us a type error:

jax.jit(dex_iota)(5)
E         RuntimeError: dtype mismatch in arg 0: expected uint32, got int32
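A plausible reading of this error, assuming JAX's default 32-bit mode (`jax_enable_x64` off): when `jax.jit` traces the call, the Python int `5` is abstracted to a weakly-typed `int32` scalar, so the primitive sees `int32` instead of the `uint32` it expects for `Nat`. The pure-JAX sketch below (no Dex involved) inspects the traced input types with `jax.make_jaxpr`; passing an explicit `jnp.uint32` would be one call-site workaround.

```python
import jax
import jax.numpy as jnp

def identity(x):
    return x

# Abstract input type that jit sees for a bare Python int.
py_int_aval = jax.make_jaxpr(identity)(5).in_avals[0]
print(py_int_aval.dtype, py_int_aval.weak_type)

# An explicitly typed uint32 scalar is traced as a strongly-typed uint32.
u32_aval = jax.make_jaxpr(identity)(jnp.uint32(5)).in_avals[0]
print(u32_aval.dtype, u32_aval.weak_type)
```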

We should probably pick one of these behaviors and stick with it (though the issue may have to do with Jax's notion of weak types).
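One way to pick a single behavior would be to coerce `Nat` arguments to `uint32` before they reach the Dex primitive, so eager and jitted calls agree. This is a hypothetical sketch: `with_nat_args` and its `nat_argnums` parameter are invented for illustration and are not part of the Dex Python API, and a plain JAX function stands in for `dex_iota`. Because the cast happens inside the traced function, the wrapped primitive would see a `uint32` tracer under `jit`.

```python
import jax
import jax.numpy as jnp

def with_nat_args(f, nat_argnums):
    """Hypothetical wrapper: cast the positional arguments listed in
    nat_argnums to uint32 (the dtype the primitive expects for Nat)."""
    def wrapped(*args):
        args = list(args)
        for i in nat_argnums:
            # Works on Python ints, concrete arrays and jit tracers alike.
            # No range check: a negative input wraps around silently.
            args[i] = jnp.asarray(args[i]).astype(jnp.uint32)
        return f(*args)
    return wrapped

# Stand-in for a Dex primitive taking a Nat; its result keeps the input dtype.
def stand_in(n):
    return n * 2

eager = with_nat_args(stand_in, (0,))(5)
jitted = jax.jit(with_nat_args(stand_in, (0,)))(5)
print(eager.dtype, jitted.dtype)
```

The main cost of this policy is that negative inputs wrap instead of raising, since under `jit` the value is not available at trace time to check.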

@apaszke
Collaborator

apaszke commented Sep 2, 2022

Yeah, weak types sound relevant here!

@axch
Collaborator Author

axch commented Sep 2, 2022

I assume what we actually want is to behave like any other Jax primitive. Do we know clearly enough what behavior that is?
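For reference, a small pure-JAX sketch of how `jax.numpy` operations treat a weakly-typed Python int, which is one candidate for "behave like any other Jax primitive": under JAX's type promotion rules the weak int adopts the dtype of the strongly-typed operand, in both eager and jitted calls. Whether a Dex primitive should mirror this is still the open question; the example only shows standard `jax.numpy` behavior.

```python
import jax
import jax.numpy as jnp

def add_five(x):
    # 5 is a weakly-typed Python int, so it takes x's dtype instead of forcing int32.
    return x + 5

x = jnp.uint32(3)
eager = add_five(x)
jitted = jax.jit(add_five)(x)
print(eager.dtype, jitted.dtype)
```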

@apaszke
Collaborator

apaszke commented Sep 3, 2022

Not sure... We should check in with others. Jake will know for sure.
