Simplest way to do data-parallel training #4300
Replies: 3 comments 3 replies
-
Hi @kriscao-cohere, I'm curious about this too. Assuming you already checked the test examples, I don't think I see a simple SPMD example there either.
-
Okay, I tried this on a Colab TPU here (a v2-8), and the following seems to work:

from flax import nnx
from functools import partial
import jax
import jax.numpy as jnp

class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    # print(f'In call, {type(x) = }')
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

model = CNN(rngs=nnx.Rngs(0))
x = jnp.ones((8, 10, 28, 28, 1))  # (devices, per-device batch, H, W, C)

# nnx.pmap's `in_axes` argument accepts a tuple of integers and nnx.StateAxes objects:
# integers are passed along to jax.pmap, while StateAxes designates separate axes for
# different state types. Here `...` (the filter for "everything") maps all state types
# to None, meaning the model state is broadcast (replicated) across devices.
state_axes = nnx.StateAxes({...: None})
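# For example, nnx.StateAxes({nnx.Param: None, ...: 0}) would broadcast the
# parameters while mapping all remaining state over the leading device axis.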
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0, devices=jax.devices())
@nnx.jit
def fwd(model, x):
  y = model(x)
  return y

y = fwd(model, x)
# print(type(x))
# nnx.display(x.shape)
nnx.display(y.shape)  # (8, 10, 10)
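If your data arrives as one flat batch, it only needs a reshape so the leading axis matches the device count before calling fwd. A small sketch (the flat batch of 80 examples here is made up for illustration):

n_devices = jax.local_device_count()
global_batch = jnp.ones((80, 28, 28, 1))          # hypothetical flat batch
per_device = global_batch.shape[0] // n_devices   # assumes it divides evenly
device_batch = global_batch.reshape(n_devices, per_device, 28, 28, 1)
y = fwd(model, device_batch)                      # (n_devices, per_device, 10)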
-
I'm adding a complete data-parallel training example below:

import os
# Make the CPU backend expose 8 devices so the example runs without accelerators.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
from jax.experimental import mesh_utils
import matplotlib.pyplot as plt
# create a mesh + shardings
num_devices = jax.local_device_count()
mesh = jax.sharding.Mesh(
  mesh_utils.create_device_mesh((num_devices,)), ('data',)
)
model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('data'))
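# An empty PartitionSpec() partitions no axes, so state placed with
# model_sharding is fully replicated on every device, while
# PartitionSpec('data') splits the leading (batch) axis across the 'data'
# axis of the mesh.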
# create model
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(nnx.relu(self.linear1(x)))
model = MLP(1, 64, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-2))
# replicate state
state = nnx.state((model, optimizer))
state = jax.device_put(state, model_sharding)
nnx.update((model, optimizer), state)
# visualize model sharding
print('model sharding')
jax.debug.visualize_array_sharding(model.linear1.kernel.value)
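# With a fully replicated kernel, the visualization shows a single tile
# spanning all devices rather than one tile per device.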
@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss
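# Note: with the batch sharded and the parameters replicated, the compiler
# inserts the cross-device gradient all-reduce automatically when the
# jitted train_step is compiled (nnx.jit wraps jax.jit); unlike jax.pmap,
# no explicit jax.lax.pmean is needed.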
def dataset(steps, batch_size):
  for _ in range(steps):
    x = np.random.uniform(-2, 2, size=(batch_size, 1))
    y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape)
    yield x, y
for step, (x, y) in enumerate(dataset(1000, 16)):
  # shard data
  x, y = jax.device_put((x, y), data_sharding)
  # train
  loss = train_step(model, optimizer, x, y)

  if step == 0:
    print('data sharding')
    jax.debug.visualize_array_sharding(x)

  if step % 100 == 0:
    print(f'step={step}, loss={loss}')
# dereplicate state
state = nnx.state((model, optimizer))
state = jax.device_get(state)
nnx.update((model, optimizer), state)
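# After device_get the state lives on the host as regular arrays, so the
# model can be called outside of jit for the plotting below.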
X, Y = next(dataset(1, 1000))
x_range = np.linspace(X.min(), X.max(), 100)[:, None]
y_pred = model(x_range)
# plot
plt.scatter(X, Y, label='data')
plt.plot(x_range, y_pred, color='black', label='model')
plt.legend()
plt.show()
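For evaluation you can reuse the same data sharding with another nnx.jit-ed function. A minimal sketch (eval_step and the 1024-example batch are illustrative, not part of the example above, and it fits most naturally before the state is pulled back to the host):

@nnx.jit
def eval_step(model: MLP, x, y):
  y_pred = model(x)
  return jnp.mean((y - y_pred) ** 2)

x_eval, y_eval = next(dataset(1, 1024))
x_eval, y_eval = jax.device_put((x_eval, y_eval), data_sharding)
print('eval loss:', eval_step(model, x_eval, y_eval))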
-
I'm a long-time JAX user, and recently I've started trying to get used to NNX. What is currently the simplest way to do data-parallel distributed training (no model sharding)? In the past I could just replicate all my model parameters to all devices with jax.device_put_replicated and use jax.pmap to automatically take care of the outer device dimension. However, I'm unsure of how to use nnx.jit to do the same thing. Any pointers would be gratefully received, as I've read https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html and am still none the wiser. Naively attempting to replicate all my model params and an input batch to every device (with a jax.tree_map on the output of nnx.split) fails, as there is a dimension mismatch with my embedding module.

EDIT: in particular, it's very inconvenient to deal with an explicit leading device dimension in all of the tensors in my modules, as some modules (such as the jax_flash_attention package) expect tensors of a particular shape. What I really want is just to pmap this over the leading device dimension, but I don't know whether such a thing is possible inside an NNX module.
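For reference, the jax.device_put_replicated + jax.pmap workflow mentioned here looks roughly like this. This is a sketch with toy parameters and a made-up loss, not code from this thread:

from functools import partial
import jax
import jax.numpy as jnp

# Toy replicated parameters: device_put_replicated stacks one copy per device,
# adding the leading device axis that pmap maps over.
params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
params = jax.device_put_replicated(params, jax.local_devices())

def loss_fn(p, x):
  return jnp.mean((x @ p['w'] + p['b']) ** 2)

@partial(jax.pmap, axis_name='batch')
def train_step(p, x):
  loss, grads = jax.value_and_grad(loss_fn)(p, x)
  grads = jax.lax.pmean(grads, axis_name='batch')  # average grads across devices
  p = jax.tree_util.tree_map(lambda a, g: a - 1e-2 * g, p, grads)
  return p, loss

n = jax.local_device_count()
x = jnp.ones((n, 32, 4))  # per-device batches stacked on the device axis
params, loss = train_step(params, x)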