Hi,

I am in the process of moving some data-parallel training code written with `xmap` over to the new `jit` API, with the intent to extend it to allow model-parallel training. The training step function, which was previously xmapped, takes a batch of data, splits it into microbatches, and performs gradient accumulation over the microbatches, with a single `pmean` at the end of the accumulation loop to synchronize gradients across devices.
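For context, here is a rough sketch of that pattern, written with `jax.pmap` and `jax.lax.pmean` rather than `xmap` for brevity; the function and variable names (`loss_fn`, `accum_step`, the toy shapes) are illustrative, not taken from my actual code:

```python
import jax
import jax.numpy as jnp


def loss_fn(params, microbatch):
    # toy loss: one matmul followed by a mean
    return jnp.mean(microbatch @ params)


def accum_step(params, batch):
    # batch has shape [num_micro_steps, per_device_batch, d_model] on each device
    def micro_step(grads, microbatch):
        g = jax.grad(loss_fn)(params, microbatch)
        return jax.tree_util.tree_map(jnp.add, grads, g), None

    grad_init = jax.tree_util.tree_map(jnp.zeros_like, params)
    grads, _ = jax.lax.scan(micro_step, grad_init, batch)
    # the only cross-device collective: a single pmean after the accumulation loop
    return jax.lax.pmean(grads, axis_name="dp")


accum_step_dp = jax.pmap(accum_step, axis_name="dp")

n_dev = jax.local_device_count()
params = jnp.ones((16, 16))
# replicate params and give each device 4 microbatches of 2 examples
grads = accum_step_dp(jnp.stack([params] * n_dev), jnp.ones((n_dev, 4, 2, 16)))
```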
I've been trying to replicate this behavior with `jit`, so far unsuccessfully (minimal code reproduction attached below). Every time `jnp.mean(loss)` is called in the accumulation loop, an all-reduce across all devices is performed, which I have confirmed with the JAX profiler. I've tried sharding the batch and then re-sharding every microbatch within the accumulation loop, but the compiled code performs this all-reduce regardless of the sharding annotations. Is there something I am missing with respect to the sharding annotations, or is this a bug?
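As an aside, the all-reduce also shows up directly in the optimized HLO text, e.g. with something like the following (using `train_step_dp`, `params`, and `batch` as defined in the reproduction below):

```python
# Dump the optimized HLO for the jitted step and look for cross-device reductions.
# Uses `train_step_dp`, `params` and `batch` from the reproduction below.
compiled = train_step_dp.lower(params, batch).compile()
hlo = compiled.as_text()
print("\n".join(line for line in hlo.splitlines() if "all-reduce" in line))
```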
Thank you!
Code to reproduce:
```python
from functools import partial
from time import time
import jax
import numpy as np
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from typing import Any
from jax.lax import with_sharding_constraint
from typing import Callable


def train_step(
    params: Any,
    batch: jnp.array,
    batch_spec: Any = None,
    grad_fn: Callable = None,
    dp_axis_size: int = None,
    per_device_parallelism: int = None,
):
    """Computes loss/grads for a single batch of data, optionally with gradient accumulation."""
    batch_size = jnp.shape(batch)[0]
    microbatch_size = dp_axis_size * per_device_parallelism
    num_micro_steps = batch_size // microbatch_size
    assert num_micro_steps * microbatch_size == batch_size

    # reshape to add a microbatch dimension
    batch = batch.reshape((num_micro_steps, microbatch_size) + batch.shape[1:])
    batch = with_sharding_constraint(
        batch, batch_spec
    )  # keep dp sharding for microbatches

    # accumulate gradients
    def cumul_minibatch_step(carry, x_y):
        cumul_loss, cumul_grads = carry
        minibatch = x_y
        loss, grads = grad_fn(to_bf16(params), minibatch)
        cumul_grads = jax.tree_map(jnp.add, cumul_grads, grads)
        return (cumul_loss + loss, cumul_grads), None

    grad_init = to_bf16(jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss, grads), _ = jax.lax.scan(
        cumul_minibatch_step, init=(jnp.zeros(()), grad_init), xs=batch
    )
    metrics = {
        "train/loss": loss,
        "train/ppl": jnp.exp(loss),
    }
    return grads, metrics


if __name__ == "__main__":
    rng = jax.random.PRNGKey(23)
    grad_acc_steps = 64
    batch_size = 512
    d_model = 2048
    n_layer = 16
    num_iter = 10
    dp = 8

    def to_bf16(t):
        return jax.tree_map(
            lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t
        )

    # Setting up device mesh
    mesh = Mesh(np.array(jax.devices()).reshape(dp), ("dp",))

    # setup sharding for data parallelism
    batch_sharding = NamedSharding(mesh, P("dp", None))
    no_shard = NamedSharding(mesh, P())
    microbatch_spec = NamedSharding(mesh, P(None, "dp", *(None,) * 1))
    param_spec = no_shard
    batch_grad_spec = no_shard

    def create_mini_model(rng):
        # create mini model that does a sequence of matmul + residual
        params = jax.random.normal(rng, shape=(n_layer, d_model, d_model))
        return params

    model_params = create_mini_model(rng)

    def fwd(batch: jnp.array, params: jnp.array):
        def layer(x, param):
            p = param
            y = jnp.dot(x, p)
            return y + x, None

        x, _ = jax.lax.scan(layer, batch, params)
        return x

    params = jax.device_put(model_params, param_spec)

    def loss_fn(params, batch):
        out = fwd(batch, params)
        loss = jnp.mean(out)
        return loss

    grad_fn = jax.value_and_grad(loss_fn, has_aux=False)

    # compute per-device batch size
    per_device_parallelism = batch_size // dp // grad_acc_steps

    with mesh:
        train_step_dp = jax.jit(
            partial(
                train_step,
                grad_fn=grad_fn,
                per_device_parallelism=per_device_parallelism,
                dp_axis_size=dp,
                batch_spec=microbatch_spec,
            ),
            in_shardings=(param_spec, batch_sharding),
            out_shardings=(param_spec, no_shard),
        )

        init_batch = jax.numpy.ones(shape=(batch_size, d_model))
        batch = jax.device_put(init_batch, batch_sharding)
        grads, metrics = train_step_dp(params, batch)

        start = time()
        for i in range(num_iter):
            # create a pseudobatch of data and send to device
            batch = jax.numpy.ones(shape=(batch_size, d_model))
            batch = jax.device_put(batch, batch_sharding)
            grads, metrics = train_step_dp(params, batch)
            grads[0].block_until_ready()
        total_time = time() - start

    print(
        f"Global BS {batch_size} - accum steps {grad_acc_steps} - Num Executions {num_iter}"
    )
    print(f"Total Time: {total_time:.4f}s")
```