Add JAXJob output
Signed-off-by: Sandipan Panda <samparksandipan@gmail.com>
sandipanpanda committed Dec 18, 2024
1 parent 6fba8a6 commit e04bcf9
Showing 6 changed files with 126 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-example-images.yaml
@@ -74,7 +74,7 @@ jobs:
     platforms: linux/amd64
     dockerfile: examples/pytorch/deepspeed-demo/Dockerfile
     context: examples/pytorch/deepspeed-demo
-  - component-name: jaxjob-mnist
+  - component-name: jaxjob-dist-spmd-mnist
     platforms: linux/amd64,linux/arm64
     dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile
     context: examples/jax/jax-dist-spmd-mnist/
4 changes: 3 additions & 1 deletion examples/jax/jax-dist-spmd-mnist/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.12
+FROM python:3.13

RUN pip install --upgrade pip
RUN pip install --upgrade jax absl-py
@@ -25,3 +25,5 @@ RUN git clone https://github.com/facebookincubator/gloo.git \
WORKDIR /app

ADD datasets.py spmd_mnist_classifier_fromscratch.py /app

+ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"]
103 changes: 103 additions & 0 deletions examples/jax/jax-dist-spmd-mnist/README.md
@@ -16,15 +16,118 @@ $ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.
```bash
$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist
```

```
NAME READY STATUS RESTARTS AGE
jaxjob-mnist-worker-0 0/1 Completed 0 108m
jaxjob-mnist-worker-1 0/1 Completed 0 108m
```

---
```bash
$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
$ kubectl logs -f ${PODNAME} -n kubeflow
```

```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)]
JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
JAX device count:16
JAX local device count:8
JAX process count:2
Epoch 0 in 1809.25 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 1 in 0.51 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 2 in 0.69 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 3 in 0.81 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 4 in 0.91 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 5 in 0.97 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 6 in 1.12 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08919999748468399
Epoch 7 in 1.11 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 8 in 1.21 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 9 in 1.29 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
```
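
The per-epoch lines in the log above follow a fixed pattern, so they can be collected into records for comparison across runs. The following is a minimal sketch, not part of this commit: `parse_epochs` and its field names are hypothetical, and the sample text is copied from the log above.

```python
import re

# Hypothetical helper (not part of the commit): pair each "Epoch N" line
# with the training/test accuracy lines that follow it.
EPOCH_RE = re.compile(r"^Epoch (\d+) in ([\d.]+) sec$")
ACC_RE = re.compile(r"^(Training|Test) set accuracy ([\d.]+)$")


def parse_epochs(log_text):
    """Return one dict per epoch with its duration and accuracies."""
    epochs = []
    for line in log_text.splitlines():
        line = line.strip()
        match = EPOCH_RE.match(line)
        if match:
            epochs.append(
                {"epoch": int(match.group(1)), "seconds": float(match.group(2))}
            )
            continue
        match = ACC_RE.match(line)
        if match and epochs:
            # Attach the accuracy to the most recently seen epoch.
            key = "train_acc" if match.group(1) == "Training" else "test_acc"
            epochs[-1][key] = float(match.group(2))
    return epochs


# Sample lines copied from the worker log above.
sample = """\
Epoch 0 in 1809.25 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 6 in 1.12 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08919999748468399
"""
epochs = parse_epochs(sample)
for record in epochs:
    print(record)
```

Feeding the output of `kubectl logs` into `parse_epochs` should give one record per epoch, assuming the script keeps printing lines in this format.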

---

```bash
$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow
```

```
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
  creationTimestamp: "2024-12-18T16:47:28Z"
  generation: 1
  name: jaxjob-mnist
  namespace: kubeflow
  resourceVersion: "3620"
  uid: 15f1db77-3326-405d-95e6-3d9a0d581611
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
            imagePullPolicy: Always
            name: jax
status:
  completionTime: "2024-12-18T17:22:11Z"
  conditions:
  - lastTransitionTime: "2024-12-18T16:47:28Z"
    lastUpdateTime: "2024-12-18T16:47:28Z"
    message: JAXJob jaxjob-mnist is created.
    reason: JAXJobCreated
    status: "True"
    type: Created
  - lastTransitionTime: "2024-12-18T16:50:57Z"
    lastUpdateTime: "2024-12-18T16:50:57Z"
    message: JAXJob kubeflow/jaxjob-mnist is running.
    reason: JAXJobRunning
    status: "False"
    type: Running
  - lastTransitionTime: "2024-12-18T17:22:11Z"
    lastUpdateTime: "2024-12-18T17:22:11Z"
    message: JAXJob kubeflow/jaxjob-mnist successfully completed.
    reason: JAXJobSucceeded
    status: "True"
    type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-12-18T16:47:28Z"
```
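
The `status.conditions` list above is what tells you whether the job finished: the `Succeeded` condition carries `status: "True"`, while `Running` has gone back to `"False"`. As a rough sketch, assuming the YAML has already been loaded into a plain dict, a check could look like this. `job_succeeded` is a hypothetical helper, not part of the Training Operator SDK, and the dict below is a hand-copied subset of the output above.

```python
# Hypothetical helper (not an SDK function): inspect a JAXJob status dict
# shaped like the `status` section printed by `kubectl get -o yaml`.
def job_succeeded(status):
    """Return True if a condition of type Succeeded has status "True"."""
    for condition in status.get("conditions", []):
        # Kubernetes reports condition status as the strings "True"/"False".
        if condition.get("type") == "Succeeded" and condition.get("status") == "True":
            return True
    return False


# Subset of the status shown above, written out by hand for illustration.
status = {
    "conditions": [
        {"type": "Created", "status": "True", "reason": "JAXJobCreated"},
        {"type": "Running", "status": "False", "reason": "JAXJobRunning"},
        {"type": "Succeeded", "status": "True", "reason": "JAXJobSucceeded"},
    ],
    "replicaStatuses": {"Worker": {"succeeded": 2}},
}

print(job_succeeded(status))
```

Comparing against the string `"True"` rather than a boolean matters here, because the condition status is quoted in the YAML.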
@@ -13,7 +13,4 @@ spec:
 containers:
   - name: jax
     image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
-    command:
-      - "python3"
-      - "spmd_mnist_classifier_fromscratch.py"
     imagePullPolicy: Always
@@ -1,3 +1,17 @@
# Copyright 2024 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An MNIST example with single-program multiple-data (SPMD) data parallelism.
The aim here is to illustrate how to use JAX's `pmap` to express and execute
@@ -34,9 +48,9 @@

 process_id = int(os.getenv("PROCESS_ID"))
 num_processes = int(os.getenv("NUM_PROCESSES"))
-coordinator_address = os.getenv("COORDINATOR_ADDRESS")
-coordinator_port = int(os.getenv("COORDINATOR_PORT"))
-coordinator_address = f"{coordinator_address}:{coordinator_port}"
+coordinator_address = (
+    f"{os.getenv('COORDINATOR_ADDRESS')}:{int(os.getenv('COORDINATOR_PORT'))}"
+)

 jax.distributed.initialize(
     coordinator_address=coordinator_address,
@@ -138,6 +152,7 @@ def replicate_array(x):

 print(f"JAX device count:{jax.device_count()}")
 print(f"JAX local device count:{jax.local_device_count()}")
+print(f"JAX process count:{jax.process_count()}")

 for epoch in range(num_epochs):
     start_time = time.time()
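
The `coordinator_address` expression in this file joins `COORDINATOR_ADDRESS` and `COORDINATOR_PORT` into a `host:port` string, and `int(os.getenv(...))` raises a bare `TypeError` when the variable is unset. A possible variant that fails with a clearer message is sketched below. `coordinator_address_from_env` is a hypothetical helper, not code from this commit, and the host name and port in the usage line are made-up values.

```python
import os


# Hypothetical helper (not part of the commit): build "host:port" from the
# same two environment variables the example script reads.
def coordinator_address_from_env(env=None):
    env = os.environ if env is None else env
    required = ("COORDINATOR_ADDRESS", "COORDINATOR_PORT")
    missing = [name for name in required if not env.get(name)]
    if missing:
        raise RuntimeError(f"missing environment variables: {', '.join(missing)}")
    # int() rejects a non-numeric port early, before JAX tries to connect.
    return f"{env['COORDINATOR_ADDRESS']}:{int(env['COORDINATOR_PORT'])}"


# Made-up values for illustration; in the job these come from the pod environment.
address = coordinator_address_from_env(
    {"COORDINATOR_ADDRESS": "jaxjob-mnist-worker-0", "COORDINATOR_PORT": "6666"}
)
print(address)
```

The returned string can be passed as `coordinator_address=` to `jax.distributed.initialize`, in the same way the script does.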
3 changes: 1 addition & 2 deletions sdk/python/test/e2e/test_e2e_jaxjob.py
@@ -155,7 +155,6 @@ def generate_jaxjob(
 def generate_container() -> V1Container:
     return V1Container(
         name=CONTAINER_NAME,
-        image="docker.io/kubeflow/jaxjob-simple:latest",
-        command=["python", "train.py"],
+        image="docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest",
         resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
     )
