As part of an active learning strategy, I run a training loop built with optax several times in a row (each iteration optimizes a vector that tells me where to sample next in the dataset). However, I'm finding that the training loop runs much faster in every iteration after the first one. I'm wondering why this is (it seems like some kind of memory allocation thing), and whether I can exploit it to make even the first iteration run faster. The training loop is pretty simple:
Later iterations of this code run at more than twice the speed of the first iteration, and I'm designing a real-time experiment, so speed is key.
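Most likely, on the first iteration JAX is tracing and compiling parts of the computation (jitted functions are compiled the first time they are called with a given set of input shapes and dtypes), while on subsequent iterations it reuses the cached compilation.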
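A minimal sketch of the effect and of one way to exploit it, assuming the loop's step can be wrapped in `jax.jit` and that input shapes and dtypes stay the same between active-learning iterations. `loss_fn`, `train_step`, and `run_loop` are hypothetical stand-ins for the question's code (which isn't shown), and plain gradient descent replaces the optax update so the example needs only JAX; a jitted step that calls an optax optimizer would be cached the same way. The idea is to pay the compilation cost with a warm-up call on dummy arrays of the right shape before the time-critical part of the experiment starts.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical objective standing in for the question's loss (not shown there).
def loss_fn(v, x):
    return jnp.sum((x @ v - 1.0) ** 2)


@jax.jit
def train_step(v, x, lr):
    # One plain gradient-descent step (an optax update would be jitted the
    # same way). jax.jit traces and compiles this function the first time it
    # is called with a new combination of argument shapes and dtypes, then
    # reuses the cached executable on later calls.
    g = jax.grad(loss_fn)(v, x)
    return v - lr * g


def run_loop(v, x, n_steps=100, lr=1e-4):
    for _ in range(n_steps):
        v = train_step(v, x, lr)
    # JAX dispatches work asynchronously; wait for the result so timings
    # measure the actual computation.
    return v.block_until_ready()


x = jnp.ones((256, 16))
v0 = jnp.zeros(16)

# Warm-up: one call on dummy arrays with the same shapes and dtypes as the
# real inputs, done before the time-critical part of the experiment, so the
# compilation cost is paid here instead of in the first real iteration.
train_step(jnp.zeros(16), jnp.zeros((256, 16)), 1e-4).block_until_ready()

for i in range(3):
    t0 = time.perf_counter()
    v = run_loop(v0, x)
    print(f"iteration {i}: {time.perf_counter() - t0:.4f} s")
```

If the arrays change shape between iterations (for example, a training set that grows as new samples are added), each new shape triggers a fresh compilation, so keeping shapes fixed (e.g. by padding to a maximum size) is one way to keep hitting the cache.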