feat: move to shard_map instead of pmap
AshishKumar4 committed Aug 4, 2024
1 parent 866f06d commit faee9d4
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions training.py
@@ -40,6 +40,10 @@

import resource

+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental import mesh_utils
+ from jax.experimental.shard_map import shard_map

#####################################################################################################################
################################################# Initialization ####################################################
#####################################################################################################################
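
The three imports added above are the building blocks of the new data-parallel path: Mesh names the devices, PartitionSpec (P) describes how an array is split across mesh axes, and shard_map maps a per-shard function over the mesh. A minimal sketch of how they typically fit together (illustrative only, not code from training.py):

import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Arrange all local devices into a 1-D mesh with a single 'data' axis.
# For a 1-D mesh this is equivalent to passing jax.devices() directly,
# which is what the trainer does below.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=('data',))

# P('data') shards an array's leading axis across the 'data' mesh axis;
# P() leaves it replicated on every device.
batch_spec = P('data')
replicated_spec = P()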
@@ -345,6 +349,7 @@ def __init__(self,
if distributed_training is None or distributed_training is True:
# Auto-detect if we are running on multiple devices
distributed_training = jax.device_count() > 1
+ self.mesh = jax.sharding.Mesh(jax.devices(), 'data')

self.distributed_training = distributed_training
self.model = model
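
With self.mesh created above, a global batch can be placed onto the devices through a NamedSharding instead of the per-device reshape the pmap path used (removed further down in this diff). A sketch under the assumption of a 1-D 'data' mesh; the batch shape and names are illustrative:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), 'data')            # same 1-D 'data' mesh as self.mesh
data_sharding = NamedSharding(mesh, P('data'))

# The global batch's leading dimension must divide evenly by the device count.
batch = jnp.zeros((32 * jax.device_count(), 64, 64, 3))   # illustrative image batch
batch = jax.device_put(batch, data_sharding)  # split across devices, no manual reshape
print(batch.sharding)                         # NamedSharding over the 'data' axis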
@@ -517,6 +522,7 @@ def _define_train_step(self, **kwargs):
loss_fn = self.loss_fn
distributed_training = self.distributed_training

+ @jax.jit
def train_step(train_state: SimpleTrainState, batch, rng_state: RandomMarkovState, local_device_indexes):
"""Train for a single step."""
images = batch['image']
@@ -530,14 +536,13 @@ def model_loss(params):
return loss
loss, grads = jax.value_and_grad(model_loss)(train_state.params)
if distributed_training:
- grads = jax.lax.pmean(grads, "device")
+ grads = jax.lax.pmean(grads, "data")
train_state = train_state.apply_gradients(grads=grads)
return train_state, loss, rng_state

if distributed_training:
- train_step = jax.pmap(axis_name="device")(train_step)
- else:
- train_step = jax.jit(train_step)
+ # train_step = jax.pmap(axis_name="data")(train_step)
+ train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())

return train_step
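
For reference, a self-contained sketch (not the repository's code) of the pattern this hunk switches to: shard_map splits the batch along the 'data' mesh axis, each device computes on its local shard, and jax.lax.pmean averages across shards just as it did under pmap, while the jax.jit decorator handles compilation in both the single- and multi-device cases.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), 'data')

def per_shard_loss(batch):
    # Each device sees only its slice of the global batch.
    local_loss = jnp.mean(batch ** 2)
    # Average the per-shard losses over the 'data' axis (pmean, as with pmap).
    return jax.lax.pmean(local_loss, axis_name='data')

loss_fn = jax.jit(shard_map(per_shard_loss, mesh=mesh,
                            in_specs=P('data'),   # shard the leading batch axis
                            out_specs=P()))       # loss is replicated across devices

global_batch = jnp.ones((8 * jax.device_count(), 4))
print(loss_fn(global_batch))  # 1.0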

@@ -607,12 +612,12 @@ def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
for i in range(steps_per_epoch):
batch = next(train_ds)
- if self.distributed_training and device_count > 1:
- batch = jax.tree.map(lambda x: x.reshape(
- (device_count, -1, *x.shape[1:])), batch)

train_state, loss, rng_state = train_step(train_state, batch, rng_state, local_device_indexes)
- loss = jnp.mean(loss)

+ if self.distributed_training:
+ loss = jax.experimental.multihost_utils.process_allgather(loss)
+ loss = jnp.mean(loss)

epoch_loss += loss
if i % 100 == 0:
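
The loss handling above changes because shard_map returns a single replicated value rather than pmap's per-device stack; on multi-host runs, process_allgather then collects that value from every process before it is averaged for logging. A small sketch of the pattern (the function name here is illustrative, not the repository's):

import jax.numpy as jnp
from jax.experimental import multihost_utils

def gather_loss_for_logging(loss, distributed_training: bool) -> float:
    if distributed_training:
        # One loss value per process, gathered onto every host.
        loss = multihost_utils.process_allgather(loss)
    return float(jnp.mean(loss))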
@@ -752,6 +757,7 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):

distributed_training = self.distributed_training

+ @jax.jit
def train_step(train_state: TrainState, batch, rng_state: RandomMarkovState, local_device_index):
"""Train for a single step."""
images = batch['image']
@@ -795,15 +801,13 @@ def model_loss(params):

loss, grads = jax.value_and_grad(model_loss)(train_state.params)
if distributed_training:
- grads = jax.lax.pmean(grads, "device")
+ grads = jax.lax.pmean(grads, "data")
train_state = train_state.apply_gradients(grads=grads)
train_state = train_state.apply_ema(self.ema_decay)
return train_state, loss, rng_state

if distributed_training:
- train_step = jax.pmap(train_step, axis_name="device")
- else:
- train_step = jax.jit(train_step)
+ train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())

return train_step
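
As a general note on the shard_map calls above: in_specs and out_specs are pytree (prefixes) matching the function's inputs and outputs, so arguments can be given individual specs, for example keeping parameters replicated while sharding only the batch. A sketch of that general pattern (illustrative, not the repository's exact specs):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), 'data')

def step(params, batch):
    def loss_fn(p):
        preds = batch @ p                # batch is the local shard only
        return jnp.mean((preds - 1.0) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Average loss and grads across the 'data' axis so every shard agrees.
    loss = jax.lax.pmean(loss, axis_name='data')
    grads = jax.lax.pmean(grads, axis_name='data')
    return loss, grads

sharded_step = jax.jit(shard_map(
    step, mesh=mesh,
    in_specs=(P(), P('data')),   # params replicated, batch sharded on axis 0
    out_specs=(P(), P())))       # loss and grads replicated after pmean

params = jnp.ones((4,))
batch = jnp.ones((8 * jax.device_count(), 4))
loss, grads = sharded_step(params, batch)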
