-
No, unfortunately there's no way in JAX's current release to pickle jaxprs. The issue is that jaxprs are built on JAX primitives, each of which has associated impl and abstract_eval rules, and many of those rules are defined using local or anonymous functions, which pickle cannot serialize. Because of that, making jaxprs fully compatible with pickle would require rewriting parts of JAX's internals.
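To illustrate the mechanism in plain Python (not JAX itself), here is a minimal sketch. `Rule`, `make_rule`, and `try_pickle` are hypothetical stand-ins for a primitive carrying an impl rule. Pickle stores functions by reference (module plus qualified name), so an object holding a module-level function round-trips, while one holding a local function or a lambda fails.

```python
import pickle


class Rule:
    """Hypothetical stand-in for a primitive with an attached impl rule."""

    def __init__(self, name, impl):
        self.name = name
        self.impl = impl


def module_level_impl(x):
    return x * 2


def make_rule():
    # impl is a local (nested) function, similar in spirit to rules
    # that are defined inside other functions.
    def impl(x):
        return x * 2

    return Rule("double", impl)


def try_pickle(obj):
    """Return True if obj survives a pickle round trip, False otherwise."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # Depending on the Python version, local objects raise
        # AttributeError or PicklingError.
        return False


print(try_pickle(Rule("double", module_level_impl)))  # True
print(try_pickle(make_rule()))                        # False
print(try_pickle(Rule("double", lambda x: x * 2)))    # False
```

A single unpicklable callable anywhere in the object graph is enough to make the whole `pickle.dump` call fail.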
-
It might be possible to use something like dill instead of pickle.
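If some serializer does manage to handle the jaxpr, the caching itself could look like the minimal sketch below. `cached_to_disk` is a hypothetical helper, and the serializer is a parameter so that `dill.dumps`/`dill.loads` could be passed in place of the `pickle` defaults. Whether dill actually succeeds on a given jaxpr is an untested assumption here.

```python
import os
import pickle
import tempfile


def cached_to_disk(path, compute, dumps=pickle.dumps, loads=pickle.loads):
    """Load the result from path if it exists; otherwise compute, save, and return it.

    dumps/loads are pluggable so another serializer (e.g. dill.dumps and
    dill.loads) can be swapped in for objects that plain pickle rejects.
    """
    if os.path.exists(path):
        with open(path, "rb") as f:
            return loads(f.read())
    result = compute()
    # Serialize before opening the file, so a serialization error
    # does not leave a partial cache file behind.
    data = dumps(result)
    with open(path, "wb") as f:
        f.write(data)
    return result


# Usage with a cheap stand-in for the slow tracing step.
calls = []


def expensive():
    calls.append(1)
    return {"eqns": 3, "outvars": ["c"]}


path = os.path.join(tempfile.mkdtemp(), "jaxpr_cache.bin")
first = cached_to_disk(path, expensive)
second = cached_to_disk(path, expensive)  # read from disk; expensive() is not called again
print(first == second, len(calls))  # True 1
```

For a real model, `compute` would wrap the call that produces the jaxpr (for example one built with `jax.make_jaxpr`), and the default `pickle` functions would be replaced by whichever serializer works.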
-
See this issue here for a related discussion.
-
Hi folks,
I'm generating a jaxpr for the forward pass of an LLM and using it for some downstream processing.
Generating the jaxpr takes ~10 minutes, and I'd like to cache the result to disk so I don't have to repeat that work.
But when I try
pkl.dump(jaxpr...)
I get the following error:
Is there any way around this?
Alternatively, what's the recommended way to pickle jaxprs?
Thanks!