Multi-host parallelization with GPUs #18659
IrishWhiskey asked this question in Q&A
Replies: 2 comments 1 reply
-
You need the EFA drivers and some additional configuration for the stack to work properly on AWS. The JAX container from jax-toolbox should work on AWS. cc: @yhtang
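One way to check which network transport NCCL actually uses, and whether the EFA stack is involved at all, is to turn on NCCL's debug logging at the very top of the script. This is a sketch that assumes the GPU build of jaxlib uses NCCL for cross-host collectives. `NCCL_DEBUG` and `NCCL_DEBUG_SUBSYS` are standard NCCL environment variables.

```python
import os

# NCCL reads these variables when it first initializes, which happens lazily
# on the first GPU collective. Setting them before importing jax is the safest.
os.environ["NCCL_DEBUG"] = "INFO"             # print NCCL's informational log lines
os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,NET"  # limit output to init and networking

# ... import jax and run the rest of the script unchanged below this point.
```

With this in place, the INFO lines printed around the failing collective should show which network NCCL selected (plain sockets or a libfabric/EFA plugin). That helps separate a missing EFA setup from other kinds of errors.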
-
I have the same question.
-
Hello,
I'm trying JAX multi-host parallelization on AWS GPU instances, but I can't get it working. What am I doing wrong? The process I followed is below.
I created two EC2 instances (`p3.2xlarge`) and configured the network so that they can communicate with each other. After installing the required dependencies, I ran the following script on both instances:

The script is not exactly the same on both instances: on the second instance I set `process_id` to `1` (the coordinator address refers to the first instance). From the terminal output I can see that everything works fine up to `print(xs)`, but the last line causes an error. Below are the outputs I got from both instances.

First instance:
Second instance:
I'm using Python `3.10.12`, jax `0.4.20` and jaxlib `0.4.20+cuda11.cudnn86`.
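For reference, the multi-process guide in the JAX documentation uses a pattern along the lines of the sketch below. This is not the original script (which is missing here). The environment variables `COORDINATOR_ADDRESS`, `NUM_PROCESSES` and `PROCESS_ID`, and the default address `10.0.0.1:1234`, are hypothetical placeholders used to pass per-instance settings. With `NUM_PROCESSES` unset, the script runs as a single process so it can be tried locally.

```python
import os

import jax
import jax.numpy as jnp

# Hypothetical environment variables used to pass per-instance settings.
# COORDINATOR_ADDRESS should be the first instance's private IP plus a free port.
COORDINATOR_ADDRESS = os.environ.get("COORDINATOR_ADDRESS", "10.0.0.1:1234")
NUM_PROCESSES = int(os.environ.get("NUM_PROCESSES", "1"))
PROCESS_ID = int(os.environ.get("PROCESS_ID", "0"))

if NUM_PROCESSES > 1:
    # Connect to the coordinator; this has to happen before any JAX computation.
    jax.distributed.initialize(
        coordinator_address=COORDINATOR_ADDRESS,
        num_processes=NUM_PROCESSES,
        process_id=PROCESS_ID,
    )

# Sanity check: the global device count should include the other host's GPUs.
print(
    f"process {jax.process_index()} of {jax.process_count()}: "
    f"{jax.local_device_count()} local / {jax.device_count()} global devices"
)

# One element per local device; creating this array needs no cross-host traffic.
xs = jnp.ones(jax.local_device_count())
print(xs)

# psum over the mapped axis sums across every device in every process,
# so this is the first step that exchanges data between hosts.
ys = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(ys)
```

To launch it, run the same file on both instances and change only `PROCESS_ID`. For example, on the first instance: `PROCESS_ID=0 NUM_PROCESSES=2 COORDINATOR_ADDRESS=<first-instance-ip>:1234 python script.py`. On the second instance, use `PROCESS_ID=1`. Each `p3.2xlarge` has a single V100 GPU, so if everything is wired up correctly, each process should report 1 local / 2 global devices and print `[2.]` for `ys`. If the device counts are right but the `psum` fails, the coordinator connection is working. The problem is then more likely in the GPU-to-GPU (NCCL) communication path.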