-
I'm running into the same problem. Do you think you can at least make batches of equally shaped inputs? If so, this seems easier to answer. The easy answer might be "make batches out of data of the same size and alternate them during training" (see the sketch below), but pay attention to your loss function: it also needs to be compatible with all shapes, i.e. its results should be of the same order of magnitude. As far as I know, …
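For what it's worth, a minimal sketch of the bucketing idea, with jit retracing once per distinct shape; the toy model, loss, shapes and learning rate are all placeholders, not code from this thread:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # toy loss; in practice keep its magnitude comparable across shapes
    preds = jnp.tanh(batch) @ params[: batch.shape[1]]
    return jnp.mean(preds ** 2)

@jax.jit  # retraces once per distinct batch shape, then reuses the cached version
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return params - 1e-3 * grads

params = jnp.ones(64)
buckets = {                        # one bucket per input shape
    (32, 16): jnp.ones((32, 16)),
    (32, 48): jnp.ones((32, 48)),
}
for epoch in range(10):
    for shape, batch in buckets.items():   # alternate equally shaped batches
        params = train_step(params, batch)
```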
-
I have a loss function that depends on the outputs of, say, hundreds of different networks of different (relatively small) sizes. Each network operates on a varying number of samples. Up until now I have concatenated all my inputs and used vmap on all my functions so that they operate on the largest possible regularly shaped matrix. However, since the loss term depends on the outputs of all the networks, and the networks are of different sizes, I don't know of any other way to parallelize or vectorize the loss, and jitting the entire loss loop as below is just too slow:
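Roughly, the structure of that loop looks like this; apply_net, the parameter layout and the final reduction are simplified placeholders for my real code:

```python
import jax
import jax.numpy as jnp

def apply_net(params, x):
    # stand-in for one small network; each network has its own widths
    return jnp.tanh(x @ params["w"]) @ params["v"]

@jax.jit
def total_loss(nets_params, nets_inputs):
    outputs = []
    # plain Python loop: unrolled at trace time into one sub-graph per network,
    # so compile time and runtime grow with the number of networks
    for params, x in zip(nets_params, nets_inputs):
        outputs.append(apply_net(params, x))
    # the loss term couples the outputs of all networks (placeholder reduction)
    return sum(jnp.sum(o ** 2) for o in outputs)

# hundreds of networks, each with its own size and sample count
sizes = [(3, 8, 10), (5, 4, 20), (7, 16, 7)]
nets_params = [{"w": jnp.ones((d, h)), "v": jnp.ones((h, 1))} for d, h, _ in sizes]
nets_inputs = [jnp.ones((n, d)) for d, _, n in sizes]
print(total_loss(nets_params, nets_inputs))
```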
Ultimately, I would like to vmap or pmap over all the networks instead of looping over them one by one as above.
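Concretely, what I would like to be able to write is something like the sketch below, which (as far as I can tell) only works when every network has identical parameter shapes, which is exactly what I don't have; the stacked layout and shapes are hypothetical:

```python
import jax
import jax.numpy as jnp

def apply_net(params, x):
    return jnp.tanh(x @ params["w"]) @ params["v"]

# hypothetical layout: 100 identically shaped networks stacked on a leading axis
stacked_params = {"w": jnp.ones((100, 3, 8)), "v": jnp.ones((100, 8, 1))}
stacked_x = jnp.ones((100, 32, 3))

outputs = jax.vmap(apply_net)(stacked_params, stacked_x)  # shape (100, 32, 1)
loss = jnp.sum(outputs ** 2)
```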
I have read, or at least skimmed, every documentation page on vmap, pmap, jax.lax.fori_loop, jax.lax.while_loop, and so on, but I don't think there is any way to achieve this. So I'm posting here in the hope that somebody can point out that I'm wrong.
Is my best hope of speeding this up to fall back on traditional multiprocessing in Python and parallelize the loop over networks across multiple CPU cores, since the networks are too small to benefit from a GPU or TPU anyway?
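In case it helps to be concrete, the kind of multiprocessing fallback I have in mind is roughly the following; per_net_loss, the shapes and the final sum are placeholder simplifications, and it only shows the forward evaluation, not gradients:

```python
import jax.numpy as jnp
from multiprocessing import get_context

def per_net_loss(args):
    params, x = args
    out = jnp.tanh(x @ params["w"]) @ params["v"]
    return float(jnp.sum(out ** 2))

if __name__ == "__main__":
    # many small networks, evaluated on CPU worker processes
    nets = [({"w": jnp.ones((3, 8)), "v": jnp.ones((8, 1))}, jnp.ones((16, 3)))
            for _ in range(100)]
    # "spawn" avoids forking a process that has already initialized JAX
    with get_context("spawn").Pool(processes=4) as pool:
        per_net = pool.map(per_net_loss, nets)
    total = sum(per_net)  # the cross-network coupling happens back in the parent
    print(total)
```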
Any thoughts are appreciated. Thank you.