BUG: Conversion of Pytensor graph to jax function renames input variables #1144
Comments
Calling by keyword argument is not a feature provided by jaxified graphs (and that utility is not from PyTensor). PyTensor `function` … Note that input names are optional and can be repeated. Also, inputs can be provided in an arbitrary order, regardless of whether they have names or not. Also, is this happening because of PyTensor or JAX?
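Because a jaxified graph takes its inputs positionally, keyword-style calling can be layered on top of it by reordering keyword arguments into the original input order. The sketch below assumes you already hold the positional function (any Python callable stands in for it here) and a list of the original input names in the same order, for example `[v.name for v in model.value_vars]`. The helper `call_with_kwargs` is hypothetical and not part of PyMC or PyTensor. Since input names are optional and may repeat, it refuses names that are missing or duplicated.

```python
from typing import Any, Callable, Sequence


def call_with_kwargs(fn: Callable[..., Any], input_names: Sequence[str]) -> Callable[..., Any]:
    """Hypothetical wrapper: let a positional-only function be called by keyword.

    `input_names` must list the original names in the same order as the
    positional inputs of `fn` (e.g. the order of `model.value_vars`).
    """
    names = list(input_names)
    # Input names are optional in PyTensor, so a None or empty name cannot be a keyword.
    if any(not isinstance(n, str) or not n for n in names):
        raise ValueError("every input needs a non-empty string name")
    # Input names may also repeat, which would make keyword lookup ambiguous.
    if len(set(names)) != len(names):
        raise ValueError("input names must be unique to be used as keywords")

    def wrapped(**kwargs: Any) -> Any:
        missing = [n for n in names if n not in kwargs]
        unexpected = sorted(set(kwargs) - set(names))
        if missing or unexpected:
            raise TypeError(f"missing={missing}, unexpected={unexpected}")
        # Reorder the keyword arguments into the positional order fn expects.
        return fn(*(kwargs[n] for n in names))

    return wrapped


# Stand-in for a jaxified function with two positional inputs.
def f(a, b):
    return a - b


g = call_with_kwargs(f, ["sigma_log__", "mu"])
print(g(mu=1.0, sigma_log__=3.0))  # 2.0
```

Keeping the original names on the Python side like this sidesteps any renaming that happens inside the generated function.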
The utility is just a thin wrapper around …
The utility … `f_loss_jax = pytensor.function(model.value_vars, outputs=[model.logp()], mode="JAX")`
What's the difference between this and …? I think we tried it, Jesse, that …
That returns a JAX function that you can `jit`/`vmap`/`grad`, whereas a PyTensor function wraps a jitted JAX function that is no longer composable. It also handles PyTensor-specific features such as keyword arguments, shared variables, and updates.
All good now, appreciate the help, guys!
Describe the issue:
Not sure if this is a bug, but @jessegrabowski and I recently noticed that when a PyTensor graph is converted into a JAX function using PyMC's `get_jaxified_graph` function, the input variables get renamed in a subtle way. Trailing double underscores in variable names get clipped to a single trailing underscore, and dashes get converted to underscores. There might be other changes, but we only noticed these two.
Is this a bug? And if not, where in the code does the renaming occur?
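One explanation consistent with both observations (an assumption, not confirmed against the PyTensor source here) is that PyTensor's JAX backend, which appears to generate Python source for the graph, sanitizes each variable name by replacing every run of non-alphanumeric characters with a single underscore. The hypothetical `sanitize_name` below reproduces that rule. It turns PyMC-style transformed names such as `sigma_log__` into `sigma_log_`, and it shows why the renaming cannot be undone from the new name alone: distinct originals can collapse to the same result.

```python
import re

# Assumed rule: collapse each run of characters outside [0-9a-zA-Z] into one "_".
_NON_ALNUM = re.compile(r"[^0-9a-zA-Z]+")


def sanitize_name(name: str) -> str:
    """Hypothetical reproduction of the observed input renaming."""
    return _NON_ALNUM.sub("_", name)


print(sanitize_name("sigma_log__"))  # sigma_log_  (trailing "__" becomes "_")
print(sanitize_name("log-scale"))    # log_scale   (dash becomes "_")
# Two different originals map to the same name, so the rule is lossy.
print(sanitize_name("a__b") == sanitize_name("a-b"))  # True
```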
Reproducible code example:
Error message:
PyTensor version information:
'2.26.3'
Context for the issue:
We're working on a minibatched SGD implementation for a GP approximation, and the last piece is inverting this variable renaming.
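If the renaming does follow a "collapse non-alphanumeric runs into one underscore" rule (an assumption that fits the two observed cases but is not verified here), it can be inverted with a lookup built from the original names, as long as no two originals collapse to the same name. The helper `build_inverse_map` is hypothetical. When the input order is known, matching inputs by position avoids the question entirely; this name-based inverse is only a fallback.

```python
import re
from typing import Dict, Iterable

# Assumed renaming rule, mirroring the behaviour reported in this issue.
_NON_ALNUM = re.compile(r"[^0-9a-zA-Z]+")


def build_inverse_map(original_names: Iterable[str]) -> Dict[str, str]:
    """Hypothetical helper: map each renamed input back to its original name.

    Raises ValueError if two different originals would receive the same new
    name, because the renaming is then ambiguous.
    """
    inverse: Dict[str, str] = {}
    for name in original_names:
        renamed = _NON_ALNUM.sub("_", name)
        if renamed in inverse and inverse[renamed] != name:
            raise ValueError(f"{inverse[renamed]!r} and {name!r} both become {renamed!r}")
        inverse[renamed] = name
    return inverse


inv = build_inverse_map(["sigma_log__", "length-scale", "mu"])
print(inv["sigma_log_"])    # sigma_log__
print(inv["length_scale"])  # length-scale
```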