Memoization of function input/output #23171
EdwardRaff asked this question in Q&A
-
Do you want to cache the compilation or the lowering (StableHLO)? We have support for both.
Lowering: https://jax.readthedocs.io/en/latest/export/export.html
Compilation: https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html
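For reference, the two mechanisms can be sketched as follows, assuming a recent JAX release that ships the `jax.export` module; the function `f` and the temporary cache directory are placeholders. Both cache compilation artifacts (serialized StableHLO or compiled executables), not the values `f` returns, and the linked page describes thresholds that decide which compiled entries actually get written to the persistent cache.

```python
import tempfile

import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# Compilation: point the persistent compilation cache at a directory so compiled
# executables can be reused across processes (a throwaway temp dir here).
jax.config.update("jax_compilation_cache_dir", tempfile.mkdtemp())

def f(x):
    # Placeholder function; stands in for any jittable computation.
    return jnp.sin(x) * 2.0

x = jnp.arange(3.0)

# Lowering: export the jitted function for a given input shape/dtype,
# then serialize the resulting StableHLO module to bytes.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct(x.shape, x.dtype))
blob = exported.serialize()

# Later (possibly in another process): deserialize and call it without re-tracing f.
restored = export.deserialize(blob)
y = restored.call(x)
```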
-
Neither. I don't care about the compilation; the calculation itself is enormously expensive. For example, I'd like to have something like:

    def expensive(x):
        return x  # but with a lot of work, e.g. it takes an hour to call

    # Plain dict mapping each input to its computed output
    # (works only for hashable x, such as Python scalars)
    cache = {}

    def memoized_expensive(x):
        if x in cache:
            return cache[x]
        cache[x] = expensive(x)
        return cache[x]

But in a way where I can still call jit/grad/etc. on …
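One way to keep a cache usable under `jit` and `grad` is to move it to the host: `jax.pure_callback` hands a Python function the concrete input values even from traced code, so a dict keyed on the input's bytes works there, and wrapping the callback in `jax.custom_jvp` supplies the derivative that `pure_callback` does not have on its own. This is a minimal sketch under stated assumptions: the expensive computation can run on the host in NumPy, `x` is a 1-D array and the output has the same shape, and a hand-written Jacobian is available. `expensive_np`, `expensive_jac_np`, `host_memoize`, and the `CALLS` counters are hypothetical names for illustration, not JAX APIs.

```python
import numpy as np
import jax
import jax.numpy as jnp

CALLS = {"value": 0, "jac": 0}  # counts real (uncached) evaluations

def expensive_np(x):
    # Hypothetical stand-in for the hour-long computation, run on the host.
    CALLS["value"] += 1
    return 2.0 * np.sin(x)

def expensive_jac_np(x):
    # Hypothetical hand-written Jacobian of expensive_np (diagonal for 2*sin(x)).
    CALLS["jac"] += 1
    return np.diag(2.0 * np.cos(x))

def host_memoize(fn):
    # Host-side cache keyed on shape, dtype, and raw bytes of the input array.
    cache = {}

    def wrapper(v):
        v = np.asarray(v)  # callbacks may receive jax.Array or np.ndarray
        key = (v.shape, v.dtype.str, v.tobytes())
        if key not in cache:
            # pure_callback requires the declared dtype, so cast to the input's.
            cache[key] = np.asarray(fn(v), dtype=v.dtype)
        return cache[key]

    return wrapper

cached_value = host_memoize(expensive_np)
cached_jac = host_memoize(expensive_jac_np)

@jax.custom_jvp
def memoized_expensive(x):
    # pure_callback lets traced code (jit/grad) call the cached host function.
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(cached_value, out, x)

@memoized_expensive.defjvp
def _memoized_expensive_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = memoized_expensive(x)
    jac_shape = jax.ShapeDtypeStruct(x.shape + x.shape, x.dtype)
    jac = jax.pure_callback(cached_jac, jac_shape, x)
    # Linear in the tangent t, so reverse mode (jax.grad) can transpose it.
    return y, jac @ t

x = jnp.array([0.1, 0.2, 0.3])
y1 = jax.jit(memoized_expensive)(x)
y2 = jax.jit(memoized_expensive)(x)  # served from the host cache
g = jax.grad(lambda v: jnp.sum(memoized_expensive(v)))
d1 = g(x)
d2 = jax.jit(g)(x)  # Jacobian also served from the cache
```

With this design the value and the Jacobian are each computed at most once per distinct input. The trade-off is that the work leaves XLA, so it cannot be fused or run on the accelerator, and the derivative has to be supplied separately rather than coming from JAX's autodiff.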
-
There appears to be some internal JAX usage of memoization in the compilation, which may be obscuring what I'm searching for. I'm not worried about compile times in my use case, but I want to memoize the input/output pairs of some functions.
For example, I have some function

    y = expensive(x)

that will only ever be called with a finite number of values for x and will be called many times. I'd love a way to memoize the function expensive rather than the less ergonomic approach of pre-calculating and looking up. Is there any clean way to do this in JAX that still supports gradients/autodiff?
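If the memoized call only ever happens at the top level on concrete arrays, a simpler option is to apply `jit` and `grad` inside the cached function rather than outside it, and cache the value and gradient together. A minimal sketch, where `memoize_concrete` and the body of `expensive` are hypothetical stand-ins; keying on `tobytes()` treats two inputs as equal only if they share shape, dtype, and exact bit pattern.

```python
import numpy as np
import jax
import jax.numpy as jnp

def memoize_concrete(fn):
    """Cache fn's result per distinct concrete input array.

    np.asarray raises on JAX tracers, so the wrapper must be called on concrete
    arrays (outside jit/grad); the transformations go inside fn instead.
    """
    cache = {}

    def wrapper(x):
        x_np = np.asarray(x)
        key = (x_np.shape, x_np.dtype.str, x_np.tobytes())
        if key not in cache:
            cache[key] = fn(x)
        return cache[key]

    wrapper.cache = cache  # exposed for inspection
    return wrapper

def expensive(x):
    # Hypothetical stand-in for the hour-long, differentiable computation.
    return jnp.sum(jnp.sin(x) ** 2)

# Memoize value and gradient together, so each distinct x is paid for once.
expensive_value_and_grad = memoize_concrete(jax.jit(jax.value_and_grad(expensive)))

x = jnp.array([0.1, 0.2, 0.3])
y, dy = expensive_value_and_grad(x)  # computed
y, dy = expensive_value_and_grad(x)  # served from the cache
```

The cost is that `expensive_value_and_grad` itself cannot be called under an outer `jit` or `grad`, since tracers have no concrete bytes to hash.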