Hi JAX community! Suppose that I can get my hands on an HLO module (proto) representing another author's serialized code, along with any metadata I might require. Suppose also that I want to 'rehydrate' this logic and use it with JAX's function transformations (like autodiff, or even more simply Python function composition). I can find plenty of examples of going from Python to Jaxpr to XLA/HLO to compiled chunks; is there a known way to go the other way around? If not, is there an alternate serialization format that would make such a thing trivially possible?
Thanks for the question!
You can write an HLO interpreter that walks the HLO program and evaluates each HLO operation in terms of `jax.numpy`, or, even easier, `jax.lax` (since that's almost 1:1 with HLO). Then you can `jax.jit` it and apply `jax.grad` to it. It'd look kind of like onnx2xla.py. Actually, we wrote something like this years ago. I'll try to find it...
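Here is a minimal sketch of the interpreter idea, assuming the HLO has already been parsed into a simple ordered list of instructions. The `PROGRAM` format, the `OPCODES` table, and the `interpret` helper are hypothetical illustrations, not part of JAX or the XLA proto API. A real rehydrator would need to parse the HLO module proto and handle shapes, per-op attributes, and nested computations.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical, simplified stand-in for a parsed HLO computation.
# Each instruction is (name, opcode, operand_names, attrs), in dependency order.
# Real HLO carries much more (shapes, layouts, sub-computations).
PROGRAM = [
    ("x", "parameter", (), {"index": 0}),
    ("y", "parameter", (), {"index": 1}),
    ("xy", "multiply", ("x", "y"), {}),
    ("s", "sine", ("xy",), {}),
    ("out", "add", ("s", "x"), {}),
]
ROOT = "out"

# Map HLO opcode names onto jax.lax functions (close to 1:1 for elementwise ops).
OPCODES = {
    "add": lax.add,
    "subtract": lax.sub,
    "multiply": lax.mul,
    "divide": lax.div,
    "negate": lax.neg,
    "maximum": lax.max,
    "sine": lax.sin,
    "cosine": lax.cos,
    "exponential": lax.exp,
    "log": lax.log,
    "tanh": lax.tanh,
}


def interpret(program, root, *args):
    """Evaluate the instructions in order, storing each result by name."""
    env = {}
    for name, opcode, operands, attrs in program:
        if opcode == "parameter":
            # Bind the function argument at the given parameter index.
            env[name] = jnp.asarray(args[attrs["index"]])
        elif opcode == "constant":
            env[name] = jnp.asarray(attrs["value"])
        else:
            fn = OPCODES[opcode]  # KeyError means the opcode is not supported yet
            env[name] = fn(*(env[o] for o in operands))
    return env[root]


# The interpreter yields an ordinary Python function, so JAX transformations apply.
def f(x, y):
    return interpret(PROGRAM, ROOT, x, y)


f_jit = jax.jit(f)
df_dx = jax.grad(f, argnums=0)
```

For this toy program f(x, y) = sin(x·y) + x, so `df_dx(1.0, 2.0)` should match 2·cos(2) + 1. Targeting `jax.lax` rather than `jax.numpy` keeps the opcode table thin; ops such as `dot`, `broadcast`, or `reduce` would additionally need their HLO attributes translated into the corresponding `lax` arguments.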