Running multiple jobs, one per TPU core? #16629
-
I read that TPUs have multiple cores. For a project I am looking at, it would be very handy to dispatch jobs to separate TPU cores, each running different compiled code for a different length of time and returning its solution to Python when it is ready. So far I have been vmapping and pmapping the same algorithm, and I now need to switch it up. Is there any documentation on how to do this, please? I would be using Colab/Kaggle initially, then moving on to a single Google Cloud TPU, and eventually to a pod slice or two. I would need fine-grained control over how jobs are assigned to chips/cores. Any advice will be gratefully received, thank you!
Replies: 3 comments 3 replies
-
I'd really appreciate your thoughts on this. Is it possible to run separate jobs on different TPU cores? Could you please point me to any relevant resources? As I said, I am presently using Colab/Kaggle for testing. Right now I have a pmap'd vmap of the same problem and it works very well, but I would very much like to be able to run two separate, unrelated jobs on the same chip. Thanks in advance for any response :-)
-
Solution: https://twitter.com/ayaka14732/status/1589274652354162690
I am also creating a library for this: https://github.com/ayaka14732/llama-jax/blob/main/lib/proc_init_utils/initialisation.py
Usage, one script per process:
# 1.py
from lib.proc_init_utils import initialise_tpu; initialise_tpu('v4-16', n_devices=1, rank=0)
# 2.py
from lib.proc_init_utils import initialise_tpu; initialise_tpu('v4-16', n_devices=1, rank=1)
# 3.py
from lib.proc_init_utils import initialise_tpu; initialise_tpu('v4-16', n_devices=1, rank=2)
# 4.py
from lib.proc_init_utils import initialise_tpu; initialise_tpu('v4-16', n_devices=1, rank=3)
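For the "return their solutions when they are ready" part of the question, here is a minimal sketch of one way to start one worker process per rank from a parent Python script and collect each result as it finishes. It uses only the standard library and is not part of the linked library. The worker source and the `WORKER_RANK` environment variable are hypothetical stand-ins for scripts like the ones above, and no TPU is touched here.

```python
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Tuple

# Hypothetical stand-in for the per-rank scripts. A real worker would call
# initialise_tpu('v4-16', n_devices=1, rank=rank) as in the usage lines above
# and then run its own compiled code.
WORKER_SOURCE = """
import os, time
rank = int(os.environ["WORKER_RANK"])  # hypothetical variable name
time.sleep(0.1 * (3 - rank))           # jobs run for different lengths of time
print(rank * rank)                     # the job's "solution"
"""


def run_worker(rank: int) -> Tuple[int, str]:
    """Run one worker process to completion and return (rank, its stdout)."""
    env = dict(os.environ, WORKER_RANK=str(rank))
    done = subprocess.run(
        [sys.executable, "-c", WORKER_SOURCE],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise if the worker exits with a non-zero status
    )
    return rank, done.stdout.strip()


def run_all(n_workers: int = 4) -> Dict[int, str]:
    """Launch one process per rank and gather results in completion order."""
    results: Dict[int, str] = {}
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [pool.submit(run_worker, rank) for rank in range(n_workers)]
        for future in as_completed(futures):  # yields each job as it finishes
            rank, output = future.result()
            results[rank] = output
    return results
```

Calling `run_all(4)` returns a dictionary keyed by rank. The threads only wait on the child processes, so the jobs themselves run in parallel as separate OS processes.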
-
Sorry to revive an old thread. Is there a way to do something like this to run a process per core instead of a process per chip? I'm thinking of the v5p architecture, which has 2 cores per chip, and I'm wondering how to make full use of it in JAX.