Hi @Edenhofer, we encountered a similar problem in a package we developed and ended up implementing a […] They are all implemented on top of […] You can try to download […]
For a reconstruction in astrophysics, I need to perform `n_los` line-of-sight (LOS) integrals within my model. The integrals boil down to weighted sums along one axis of the input. The weights are computed using non-linear functions of some constants, but the LOS integrals themselves are linear in the input. In other words, the LOS integrals can be thought of as a very large implicit matrix. Unfortunately, due to memory constraints, it is not possible to batch all LOS integrals in parallel. Instead, I would like to batch through the integrals in chunks.

I have two main constraints for computing the LOS integrals: (1) they need to be fast and (2) they need to be memory efficient. Both constraints seem pretty standard but require some extra thought, especially when differentiating through the model. Naively, the backward pass would, e.g., store all intermediate LOS weights instead of rematerializing them on the fly. Furthermore, because different LOS integrals access the same elements of an array, the backward pass involves a lot of atomic adds.
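A minimal sketch of the chunked, matrix-free LOS operator described above, assuming a hypothetical geometry: a 1-D `field` of length `n_grid`, `n_los` lines of sight that each sample the field at `n_samples` fractional positions `pos`, and placeholder weights (linear interpolation times a made-up attenuation `exp(-pos / n_grid)`). The helper names `los_weights`, `los_chunk` and `los_integrals` are illustrative, not from the post. The weights are computed inside each chunk, `jax.lax.map` loops over the chunks sequentially, and `jax.checkpoint` asks the backward pass to recompute the weights instead of storing them for every chunk.

```python
import jax
import jax.numpy as jnp


def los_weights(pos, n_grid):
    """Hypothetical weights: linear interpolation times exp(-pos / n_grid)."""
    lo = jnp.clip(jnp.floor(pos).astype(jnp.int32), 0, n_grid - 2)
    frac = pos - lo
    atten = jnp.exp(-pos / n_grid)  # non-linear, but only in the constants
    idx = jnp.stack([lo, lo + 1], axis=-1)  # (..., n_samples, 2)
    w = jnp.stack([(1.0 - frac) * atten, frac * atten], axis=-1)
    return idx, w


def los_chunk(field, pos_chunk):
    """Integrate one chunk of lines of sight; linear in `field`."""
    idx, w = los_weights(pos_chunk, field.shape[0])
    # Gather the field at both interpolation nodes and sum over samples.
    return jnp.sum(w * field[idx], axis=(-2, -1))


def los_integrals(field, pos, chunk_size):
    """Apply the implicit (n_los, n_grid) matrix chunk by chunk."""
    n_los, n_samples = pos.shape
    assert n_los % chunk_size == 0, "sketch assumes equal-sized chunks"
    pos_chunks = pos.reshape(n_los // chunk_size, chunk_size, n_samples)
    # Recompute idx/w in the backward pass instead of saving them per chunk.
    chunk_fn = jax.checkpoint(los_chunk)
    out = jax.lax.map(lambda p: chunk_fn(field, p), pos_chunks)
    return out.reshape(n_los)
```

Because the gather `field[idx]` reuses grid cells across lines of sight, its gradient is a scatter-add; chunking bounds the memory of both passes but does not by itself remove those accumulations.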
My current approach to fulfill both constraints involves manually batching the `n_los` LOS integrals and defining a custom JVP and VJP (via `jax.custom_jvp` and `jax.custom_derivatives.linear_call`). There are multiple downsides to this approach, including that the model becomes un-batchable after using `linear_call` and that second derivatives become inaccessible.
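Since the operator is linear, its VJP is the transposed operator applied to the cotangent, so a hand-written backward pass only needs the geometry, not the weights. A minimal sketch of that idea, using `jax.custom_vjp` rather than the `jax.custom_jvp` plus `linear_call` combination from the post and reusing a hypothetical geometry with placeholder weights (all helper names are illustrative): the backward pass recomputes the weights chunk by chunk and accumulates `A.T @ cot` with a scatter-add inside `jax.lax.scan`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def los_weights(pos, n_grid):
    """Hypothetical weights: linear interpolation times exp(-pos / n_grid)."""
    lo = jnp.clip(jnp.floor(pos).astype(jnp.int32), 0, n_grid - 2)
    frac = pos - lo
    atten = jnp.exp(-pos / n_grid)
    idx = jnp.stack([lo, lo + 1], axis=-1)
    w = jnp.stack([(1.0 - frac) * atten, frac * atten], axis=-1)
    return idx, w


def apply_a(field, pos_chunks):
    """y = A @ field, one chunk of lines of sight at a time."""
    def body(p):
        idx, w = los_weights(p, field.shape[0])
        return jnp.sum(w * field[idx], axis=(-2, -1))
    return jax.lax.map(body, pos_chunks).reshape(-1)


def apply_a_transpose(cot, pos_chunks, n_grid):
    """field_bar = A.T @ cot, accumulated chunk by chunk via scatter-add."""
    cot_chunks = cot.reshape(pos_chunks.shape[:2])

    def body(acc, xs):
        p, c = xs
        idx, w = los_weights(p, n_grid)  # recomputed, never stored
        return acc.at[idx].add(w * c[:, None, None]), None

    acc0 = jnp.zeros(n_grid, dtype=cot.dtype)
    acc, _ = jax.lax.scan(body, acc0, (pos_chunks, cot_chunks))
    return acc


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def los_integrals(field, pos_chunks, n_grid):
    return apply_a(field, pos_chunks)


def los_fwd(field, pos_chunks, n_grid):
    # Residuals: only the geometry, neither the weights nor the field.
    return apply_a(field, pos_chunks), pos_chunks


def los_bwd(n_grid, pos_chunks, cot):
    # The geometry is treated as constant, so its cotangent is zero.
    return apply_a_transpose(cot, pos_chunks, n_grid), jnp.zeros_like(pos_chunks)


los_integrals.defvjp(los_fwd, los_bwd)
```

A `custom_vjp` function still works under `jax.vmap`, and reverse-over-reverse differentiation should go through because `los_bwd` is written in ordinary JAX operations; forward mode (`jax.jvp`, `jax.jacfwd`) is not available for it, though.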
(Hopefully all of this will get better with #9073 and #9129.)

To get rid of all of my problems, simplify the code, and use more experimental JAX features (:D), I decided to give `xmap` and `SerialLoop` a try. Unfortunately, I can't get them to work for my problem. I would be glad if someone could help me get `SerialLoop` into `axis_resources` and guide me through how I should tell JAX about my desired rematerialization strategy.