
Rehydrating a Jaxpr from HLO / proto? #14147

Answered by mattjj
jkr26 asked this question in Q&A
Jan 25, 2023 · 1 comment · 2 replies

Thanks for the question!

You can write an HLO interpreter that walks the HLO program and evaluates each HLO operation in terms of jax.numpy, or, even easier, jax.lax (since jax.lax is almost 1:1 with HLO). You can then apply jax.jit and jax.grad to the interpreted function. It would look something like onnx2xla.py.
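A minimal sketch of that idea, assuming a hypothetical, simplified instruction list in place of real HLO text or an `HloModuleProto` (no parsing, shapes, layouts, or attributes). The names `PROGRAM`, `LAX_OPS`, and `interpret` are illustrative only. Each opcode is looked up in a dict of `jax.lax` functions, and the resulting Python function can be passed to `jax.jit` and `jax.grad`:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for an HLO computation: (result_name, opcode, operands).
# Real HLO carries shapes, layouts, attributes, and nested computations,
# all of which this toy format ignores.
PROGRAM = [
    ("p0", "parameter", (0,)),           # operand is the parameter index
    ("p1", "parameter", (1,)),
    ("mul", "multiply", ("p0", "p1")),
    ("s", "sine", ("mul",)),
    ("out", "add", ("s", "p0")),
]
ROOT = "out"

# HLO opcode name -> jax.lax function (nearly 1:1 for elementwise ops).
LAX_OPS = {
    "add": lax.add,
    "subtract": lax.sub,
    "multiply": lax.mul,
    "divide": lax.div,
    "sine": lax.sin,
    "cosine": lax.cos,
    "exponential": lax.exp,
    "tanh": lax.tanh,
}

def interpret(program, root, *args):
    """Walk the instructions in order, evaluating each one with jax.lax."""
    env = {}
    for name, opcode, operands in program:
        if opcode == "parameter":
            env[name] = args[operands[0]]
        else:
            env[name] = LAX_OPS[opcode](*(env[o] for o in operands))
    return env[root]

def f(x, y):
    # f(x, y) = sin(x * y) + x, rebuilt from the instruction list
    return interpret(PROGRAM, ROOT, x, y)

f_jit = jax.jit(f)               # traced once, compiled by XLA
df_dx = jax.grad(f, argnums=0)   # d/dx = y * cos(x * y) + 1
```

Because the interpreter only calls `jax.lax` functions, JAX traces straight through the Python loop. Ops that carry attributes, such as `dot` (dimension numbers) or `broadcast`, would additionally need those attributes parsed and passed through, e.g. to `lax.dot_general` or `lax.broadcast_in_dim`.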

Actually, we wrote something like this years ago. I'll try to find it...

Answer selected by jkr26