diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index 5df38f1f37..5012714b57 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -74,7 +74,7 @@ jobs: platforms: linux/amd64 dockerfile: examples/pytorch/deepspeed-demo/Dockerfile context: examples/pytorch/deepspeed-demo - - component-name: jaxjob-mnist + - component-name: jaxjob-dist-spmd-mnist platforms: linux/amd64,linux/arm64 dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile context: examples/jax/jax-dist-spmd-mnist/ diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile index 805f222a35..92b406f117 100644 --- a/examples/jax/jax-dist-spmd-mnist/Dockerfile +++ b/examples/jax/jax-dist-spmd-mnist/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12 +FROM python:3.13 RUN pip install --upgrade pip RUN pip install --upgrade jax absl-py @@ -25,3 +25,5 @@ RUN git clone https://github.com/facebookincubator/gloo.git \ WORKDIR /app ADD datasets.py spmd_mnist_classifier_fromscratch.py /app + +ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"] diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md index 3f44afb615..6194ea9eda 100644 --- a/examples/jax/jax-dist-spmd-mnist/README.md +++ b/examples/jax/jax-dist-spmd-mnist/README.md @@ -16,6 +16,13 @@ $ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo. ```bash $ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist ``` + +``` +NAME READY STATUS RESTARTS AGE +jaxjob-mnist-worker-0 0/1 Completed 0 108m +jaxjob-mnist-worker-1 0/1 Completed 0 108m +``` + --- ```bash $ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o @@ -23,8 +30,104 @@ name -n kubeflow) $ kubectl logs -f ${PODNAME} -n kubeflow ``` +``` +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/ +JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)] +JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)] +JAX device count:16 +JAX local device count:8 +JAX process count:2 +Epoch 0 in 1809.25 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 1 in 0.51 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 2 in 0.69 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 3 in 0.81 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 4 in 0.91 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 5 in 0.97 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 6 in 1.12 sec +Training set accuracy 0.09035000205039978 +Test set accuracy 0.08919999748468399 +Epoch 7 in 1.11 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 8 in 1.21 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 9 in 1.29 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 + +``` + --- ```bash $ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow ``` + +``` +apiVersion: kubeflow.org/v1 +kind: JAXJob +metadata: + annotations: + kubectl.kubernetes.io/last-applied-configuration: | + {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}} + creationTimestamp: "2024-12-18T16:47:28Z" + generation: 1 + name: jaxjob-mnist + namespace: kubeflow + resourceVersion: "3620" + uid: 15f1db77-3326-405d-95e6-3d9a0d581611 +spec: + jaxReplicaSpecs: + Worker: + replicas: 2 + restartPolicy: OnFailure + template: + spec: + containers: + - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest + imagePullPolicy: Always + name: jax +status: + completionTime: "2024-12-18T17:22:11Z" + conditions: + - lastTransitionTime: "2024-12-18T16:47:28Z" + lastUpdateTime: "2024-12-18T16:47:28Z" + message: JAXJob jaxjob-mnist is created. + reason: JAXJobCreated + status: "True" + type: Created + - lastTransitionTime: "2024-12-18T16:50:57Z" + lastUpdateTime: "2024-12-18T16:50:57Z" + message: JAXJob kubeflow/jaxjob-mnist is running. + reason: JAXJobRunning + status: "False" + type: Running + - lastTransitionTime: "2024-12-18T17:22:11Z" + lastUpdateTime: "2024-12-18T17:22:11Z" + message: JAXJob kubeflow/jaxjob-mnist successfully completed. + reason: JAXJobSucceeded + status: "True" + type: Succeeded + replicaStatuses: + Worker: + selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker + succeeded: 2 + startTime: "2024-12-18T16:47:28Z" + +``` diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml index 912ebde719..50bd66f583 100644 --- a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml +++ b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml @@ -13,7 +13,4 @@ spec: containers: - name: jax image: docker.io/sandipanify/jaxjob-spmd-mnist:latest - command: - - "python3" - - "spmd_mnist_classifier_fromscratch.py" imagePullPolicy: Always diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index 5982963ba7..41f55d745c 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -1,3 +1,17 @@ +# Copyright 2024 kubeflow.org. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """An MNIST example with single-program multiple-data (SPMD) data parallelism. The aim here is to illustrate how to use JAX's `pmap` to express and execute @@ -34,9 +48,9 @@ process_id = int(os.getenv("PROCESS_ID")) num_processes = int(os.getenv("NUM_PROCESSES")) -coordinator_address = os.getenv("COORDINATOR_ADDRESS") -coordinator_port = int(os.getenv("COORDINATOR_PORT")) -coordinator_address = f"{coordinator_address}:{coordinator_port}" +coordinator_address = ( + f"{os.getenv('COORDINATOR_ADDRESS')}:{int(os.getenv('COORDINATOR_PORT'))}" +) jax.distributed.initialize( coordinator_address=coordinator_address, @@ -138,6 +152,7 @@ def replicate_array(x): print(f"JAX device count:{jax.device_count()}") print(f"JAX local device count:{jax.local_device_count()}") + print(f"JAX process count:{jax.process_count()}") for epoch in range(num_epochs): start_time = time.time() diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 6223c8a988..98cc5fff49 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -155,7 +155,6 @@ def generate_jaxjob( def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, - image="docker.io/kubeflow/jaxjob-simple:latest", - command=["python", "train.py"], + image="docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest", resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), )