Autodifferentiation through parallelized operators with xmap #14982
cmunna0052 asked this question in Q&A
This question concerns the same basic setup as my previous one, #14879, but with a slightly different approach to sharding that gets a bit further before breaking. I am still trying to shard a feed-forward network on the MNIST dataset by splitting the weight matrices into 10 groups of columns. Now, however, I have defined an xmapped matrix multiplication operation on its own, with the following code:
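(Sketched below in simplified form: `shard_matmul`, the `shards` axis name, and the example layer sizes are illustrative stand-ins rather than the exact listing.)

```python
import jax.numpy as jnp
from jax.experimental.maps import xmap

N_SHARDS = 10  # the weight matrix is split into 10 groups of columns

def local_matmul(x, w_block):
    # x: the full input activations; w_block: one shard's group of columns
    return jnp.dot(x, w_block)

# x is replicated across shards, while w is mapped over its leading 'shards'
# axis, so the output carries the 'shards' named axis (one column block each).
# With a device mesh active, axis_resources={'shards': <mesh axis>} would
# additionally spread the shards over devices.
shard_matmul = xmap(
    local_matmul,
    in_axes=([...], ['shards', ...]),
    out_axes=['shards', ...],
)

# Example shapes: 10 blocks of 30 columns stand in for a 784x300 layer
x = jnp.ones((784,))
w = jnp.ones((N_SHARDS, 784, 30))
y = shard_matmul(x, w)  # shape (N_SHARDS, 30)
```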
This setup works (though I am not sure why I don't have to recombine the x vector at each step with jax.lax.all_gather; in fact, doing so causes an error).
Now the problem comes at the next step, where I try to backpropagate with the following:
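(Again a simplified sketch rather than the exact code: `loss_fn`, `train_batch`, the plain SGD update, and treating the sharded layer's output directly as the logits are all assumptions, building on the `shard_matmul` sketch above.)

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x_batch, y_batch):
    # shard_matmul broadcasts over the batch dimension of x_batch, returning
    # (N_SHARDS, batch, block); fold the shards back into one logits axis.
    out = shard_matmul(x_batch, w)
    logits = jnp.transpose(out, (1, 0, 2)).reshape(x_batch.shape[0], -1)
    # y_batch: one-hot targets matching the width of the folded output
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * y_batch, axis=-1))

def train_batch(w, x_batch, y_batch, lr=0.1):
    # backpropagate through the xmapped operator and take one SGD step
    loss, grads = jax.value_and_grad(loss_fn)(w, x_batch, y_batch)
    return loss, w - lr * grads
```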
This correctly calculates the loss but fails in the train_batch portion with the error `assert len(arg) == n, f'length mismatch: [6,6,2]'`. I added in the custom_vjp because the regular jax.lax.pmean was throwing a similar error, and I assumed it wasn't a differentiable operator anyway (the wrapper I'm using is sketched below). Any ideas on how I should get through this?
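For reference, the custom_vjp wrapper is roughly of this shape (a simplified sketch: the `batch` axis name and the pass-through backward rule are assumptions):

```python
import jax

@jax.custom_vjp
def pmean_nondiff(x):
    # average across the named 'batch' axis without asking JAX to derive a VJP
    # (only valid inside a map that binds the 'batch' named axis)
    return jax.lax.pmean(x, axis_name='batch')

def _pmean_fwd(x):
    return pmean_nondiff(x), None

def _pmean_bwd(_, g):
    # pass the cotangent straight through, treating the mean as an identity
    return (g,)

pmean_nondiff.defvjp(_pmean_fwd, _pmean_bwd)
```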