Suppose we wish to export a Dex function that takes a Nat argument to Jax:
dex_iota = primitive(dex.eval(r"\(size:Nat). for i:(Fin size). ordinal i"))
If we just call it, it works ok:
dex_iota(5)
> [0, 1, 2, 3, 4]
But if we jit it first, it shows us a type error:
jax.jit(dex_iota)(5)
E RuntimeError: dtype mismatch in arg 0: expected uint32, got int32
We should probably pick one of these behaviors and stick with it (though the issue may have to do with Jax's notion of weak types).
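For reference, here is a minimal sketch of the weak-type behavior in plain Jax, without Dex. It assumes the default configuration (64-bit types disabled), and `double` is a hypothetical stand-in function, not part of the Dex bindings. It only shows how Jax treats a bare Python int compared with an explicit uint32 value. Whether casting the argument this way also avoids the error in `dex_iota` is untested here.

```python
import jax
import jax.numpy as jnp

# A bare Python int is converted to a weakly typed signed integer
# (int32 when 64-bit types are disabled, which is the default).
py_int = jnp.asarray(5)
print(py_int.dtype, py_int.weak_type)

# An explicitly constructed uint32 scalar is strongly typed.
u = jnp.uint32(5)
print(u.dtype, u.weak_type)

# Hypothetical stand-in for a jitted function; not the Dex primitive.
@jax.jit
def double(x):
    return x * 2

# Under jit the Python int is traced as a signed integer, while the
# uint32 argument keeps its dtype because the literal 2 is weakly typed.
print(double(5).dtype)
print(double(u).dtype)
```

So a jitted function that is called with `5` sees a signed integer argument, which matches the "got int32" part of the error message above.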
Yeah weak types sound relevant here!
I assume what we actually want is to behave like any other Jax primitive. Do we know clearly enough what behavior that is?
Not sure... We should check in with others. Jake will know for sure.