Skip to content

Commit

Permalink
Update for the newer JAX version (0.4.12)
Browse files Browse the repository at this point in the history
  • Loading branch information
che-shr-cat committed Jun 20, 2023
1 parent 3274899 commit 281883e
Show file tree
Hide file tree
Showing 2 changed files with 421 additions and 297 deletions.
35 changes: 24 additions & 11 deletions Chapter-8/JAX_in_Action_Chapter_8_pjit.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
{
"cells": [
{
"cell_type": "code",
"source": [
"# Install the latest JAXlib version.\n",
"# This code was originally tested on the 0.4.3 version, then updated to the 0.4.12\n",
"!pip install --upgrade -q pip jax jaxlib"
],
"metadata": {
"id": "u5zBfJFH2Uyt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -534,8 +547,8 @@
},
"outputs": [],
"source": [
"from jax.experimental.maps import Mesh\n",
"from jax.experimental import PartitionSpec"
"from jax.sharding import Mesh\n",
"from jax.sharding import PartitionSpec"
]
},
{
Expand Down Expand Up @@ -1150,8 +1163,8 @@
},
"outputs": [],
"source": [
"from jax.interpreters.pxla import PartitionSpec as P # could be useful to reduce typing\n",
"from jax.experimental.maps import Mesh\n",
"from jax.sharding import PartitionSpec as P # could be useful to reduce typing\n",
"from jax.sharding import Mesh\n",
"import numpy as np"
]
},
Expand Down Expand Up @@ -1434,7 +1447,7 @@
{
"cell_type": "markdown",
"source": [
"Install these modules if you created a new empty cloud machine "
"Install these modules if you created a new empty cloud machine"
],
"metadata": {
"id": "wFW5RFD2NqOI"
Expand Down Expand Up @@ -1728,7 +1741,7 @@
"\n",
"data, info = tfds.load(name=\"mnist\",\n",
" data_dir=data_dir,\n",
" as_supervised=True, \n",
" as_supervised=True,\n",
" with_info=True)\n",
"\n",
"data_train = data['train']\n",
Expand All @@ -1746,7 +1759,7 @@
"HEIGHT = 28\n",
"WIDTH = 28\n",
"CHANNELS = 1\n",
"NUM_PIXELS = HEIGHT * WIDTH * CHANNELS \n",
"NUM_PIXELS = HEIGHT * WIDTH * CHANNELS\n",
"NUM_LABELS = info.features['label'].num_classes\n",
"NUM_DEVICES = jax.device_count()\n",
"BATCH_SIZE = 32"
Expand Down Expand Up @@ -1846,7 +1859,7 @@
" \"\"\"Initialize all layers for a fully-connected neural network with given sizes\"\"\"\n",
"\n",
" def random_layer_params(m, n, key, scale=1e-2):\n",
" \"\"\"A helper function to randomly initialize weights and biases of a dense layer\"\"\" \n",
" \"\"\"A helper function to randomly initialize weights and biases of a dense layer\"\"\"\n",
" w_key, b_key = random.split(key)\n",
" return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
"\n",
Expand All @@ -1870,7 +1883,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = swish(outputs)\n",
" \n",
"\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits\n",
Expand All @@ -1896,8 +1909,8 @@
"outputs": [],
"source": [
"from jax.experimental.pjit import pjit\n",
"from jax.interpreters.pxla import PartitionSpec as P\n",
"from jax.experimental.maps import Mesh\n",
"from jax.sharding import PartitionSpec as P\n",
"from jax.sharding import Mesh\n",
"import numpy as np"
]
},
Expand Down
Loading

0 comments on commit 281883e

Please sign in to comment.