
BUG: Conversion of Pytensor graph to jax function renames input variables #1144

Closed
bwengals opened this issue Dec 31, 2024 · 5 comments
Labels
bug (Something isn't working), jax, needs info (Additional information required)

Comments

@bwengals

Describe the issue:

Not sure if this is a bug, but @jessegrabowski and I noticed recently that when a PyTensor graph is converted into a JAX function using PyMC's get_jaxified_graph function, the input variables get renamed in a subtle way.

Trailing double underscores on variable names get clipped to a single trailing underscore, and dashes get converted to underscores. There may be other changes, but these are the two we noticed.

Is this a bug? And if not, where in the code does the renaming occur?

Reproducible code example:

import pymc as pm
from pymc.sampling.jax import get_jaxified_graph  # location of the utility in recent PyMC versions

with pm.Model() as model:
    x = pm.Flat("x-x__")
    mu = pm.Lognormal("mu") + x
    sigma = pm.HalfNormal("sigma-sigma")
    pm.Normal("y", mu=mu, sigma=sigma, observed=[1.0, 2.0, 3.0])

f_loss_jax = get_jaxified_graph(model.value_vars, [model.logp()])
f_loss_jax

Error message:

<function pytensor.link.utils.jax_funcified_fgraph(x_x_, mu_log_, sigma_sigma_log_)>

PyTensor version information:

'2.26.3'

Context for the issue:

We're working on a minibatched SGD implementation for a GP approximation, and the last piece is inverting this variable renaming.
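One way to sidestep the renaming (a sketch, not from the original report; the use of model.initial_point() here is an assumption): since the compiled inputs follow the order of model.value_vars, the renamed keyword names can be ignored and the arguments passed positionally.

# Hypothetical sketch: call the jaxified logp positionally, keyed by the
# original value-variable names rather than the renamed ones.
point = model.initial_point()                      # dict keyed by the original value-var names
args = [point[v.name] for v in model.value_vars]   # same order as the compiled inputs
logp_val = f_loss_jax(*args)[0]                    # outputs come back as a sequence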

@bwengals added the bug label Dec 31, 2024
@ricardoV94
Member

ricardoV94 commented Dec 31, 2024

Calling by keyword argument is not a supported feature of jaxified graphs (and that utility is not from PyTensor).

PyTensor's Function.__call__ does the keyword-to-positional argument mapping. If you sidestep the PyTensor function, you have to do that mapping yourself.

Note that input names are optional and can be repeated. Inputs can also be provided in an arbitrary order, regardless of whether they have names.

Also is this happening because of PyTensor or JAX?
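For illustration (a minimal sketch, not part of the thread): the keyword-to-positional mapping lives in Function.__call__, so a compiled PyTensor function accepts input names, while the raw jaxified callable only takes positional arguments.

import pytensor
import pytensor.tensor as pt

# hypothetical standalone example, unrelated to the model above
a = pt.dscalar("alpha")
b = pt.dscalar("beta")
f = pytensor.function([a, b], a + 2 * b)

f(1.0, 2.0)             # positional call
f(beta=2.0, alpha=1.0)  # keyword call, resolved by Function.__call__
# a jaxified graph has no such wrapper, so only positional calls work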

@ricardoV94 added the jax, needs info, and bug labels and removed the bug label Dec 31, 2024
@jessegrabowski
Member

The utility is just a thin wrapper around pytensor.function anyway. We should test if we can get away with just doing that directly.

@bwengals
Author

bwengals commented Jan 2, 2025

The utility get_jaxified_graph is a thin wrapper around pytensor.function? Meaning:

f_loss_jax = pytensor.function(model.value_vars, outputs=[model.logp()], mode="JAX")

What's the difference between this and what get_jaxified_graph returns?

I think we tried that, Jesse; the pytensor.function call above is commented out in our code right below it.

@bwengals
Author

bwengals commented Jan 4, 2025

All good now, appreciate the help, guys.

@bwengals bwengals closed this as completed Jan 4, 2025
@ricardoV94
Member

What's the difference between this and what get_jaxified_graph returns?

That returns a JAX function that you can jit/vmap/grad, whereas pytensor.function with mode="JAX" wraps an already-jitted JAX function that is no longer composable. It also handles PyTensor-specific things like keyword arguments, shared variables, and updates.
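As a rough sketch of that difference (reusing f_loss_jax and model from the example above; the jax.grad usage is illustrative, not from the thread):

import jax

point = model.initial_point()
args = [point[v.name] for v in model.value_vars]

# the raw jaxified function composes with JAX transforms
grad_logp = jax.grad(lambda *a: f_loss_jax(*a)[0], argnums=tuple(range(len(args))))
grads = grad_logp(*args)

# a pytensor.function(..., mode="JAX") result is already wrapped and jitted,
# so it cannot be passed through jax.grad / jax.vmap / jax.jit like this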
