Implementing L2L in JAX #5871
davisyoshida asked this question in Q&A
I'm interested in implementing a decorator that swaps parameters on and off the GPU, as described in the L2L paper: https://arxiv.org/pdf/2002.05645.pdf
To get started, I made the following decorator:
Unfortunately, functions decorated this way cannot be JIT-compiled, because they take arguments living on both the CPU and the GPU. Is there a way to do something like this (or some other method entirely) and still make use of `jax.jit`?
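One possible workaround, sketched under assumptions and not taken from the L2L paper or the JAX docs, is to keep every layer's parameters in host memory as NumPy arrays and apply `jax.jit` only to the per-layer computation. The loop over layers then stays in plain Python, so each jitted call receives arguments that all sit on a single device. The names `layer_forward`, `init_host_params`, and `l2l_forward` are hypothetical. The layers are assumed to be simple dense `tanh` layers, and `jax.devices()[0]` is taken as the accelerator (it is the CPU when no GPU is present).

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical per-layer computation: a dense layer with tanh activation.
# Only this function is jitted, so all of its inputs live on one device.
@jax.jit
def layer_forward(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

def init_host_params(sizes, seed=0):
    """Create per-layer (w, b) pairs as NumPy arrays held in host (CPU) memory."""
    rng = np.random.default_rng(seed)
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        w = (rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)).astype(np.float32)
        b = np.zeros((n_out,), dtype=np.float32)
        params.append((w, b))
    return params

def l2l_forward(host_params, x, device=None):
    """Run layers one at a time, staging only the current layer's params on `device`."""
    if device is None:
        device = jax.devices()[0]  # first device of the default backend (GPU if available)
    x = jax.device_put(x, device)
    for layer_params in host_params:
        # Copy just this layer's weights from host memory to the device.
        dev_params = jax.device_put(layer_params, device)
        x = layer_forward(dev_params, x)
        # Drop the device copy so its buffers can be released before the next layer.
        del dev_params
    return x

host_params = init_host_params([8, 16, 16, 4])
out = l2l_forward(host_params, np.ones((2, 8), np.float32))
print(out.shape)  # (2, 4)
```

Because the layer loop runs outside `jit`, the model is not compiled as a single program: `layer_forward` is compiled once per distinct input shape and reused afterward, so stacks of identically shaped layers share one compiled executable. This sketch covers only the forward pass. A full training step would also have to move gradients and optimizer state between host and device, which it does not attempt.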