Re-added pmap since shard_map doesn't work
AshishKumar4 committed Aug 4, 2024
1 parent faee9d4 commit 6b87ce1
Showing 1 changed file with 13 additions and 6 deletions.
training.py
@@ -522,7 +522,6 @@ def _define_train_step(self, **kwargs):
         loss_fn = self.loss_fn
         distributed_training = self.distributed_training
 
-        @jax.jit
         def train_step(train_state: SimpleTrainState, batch, rng_state: RandomMarkovState, local_device_indexes):
             """Train for a single step."""
             images = batch['image']
@@ -541,8 +540,10 @@ def model_loss(params):
             return train_state, loss, rng_state
 
         if distributed_training:
-            # train_step = jax.pmap(axis_name="data")(train_step)
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
+            train_step = jax.pmap(axis_name="data")(train_step)
+            # train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
+        else:
+            train_step = jax.jit(train_step)
 
         return train_step
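
For context, a minimal sketch (not the repo's code) of what jax.pmap(axis_name="data") does to a step function: it maps the function over the leading device axis, and the named axis lets collectives such as jax.lax.pmean average gradients across replicas. The step function, loss, and shapes below are hypothetical; the repo's real train_step uses its own loss_fn.

import jax
import jax.numpy as jnp

def step(params, batch):
    # Hypothetical scalar loss for illustration only.
    loss, grads = jax.value_and_grad(lambda p: jnp.mean((batch * p) ** 2))(params)
    # Average gradients over the "data" axis so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="data")
    return params - 0.1 * grads, loss

p_step = jax.pmap(step, axis_name="data")
n = jax.local_device_count()
params = jnp.ones((n,))                    # one (scalar) copy per device
batch = jnp.arange(n * 4.0).reshape(n, 4)  # leading axis = device axis
params, loss = p_step(params, batch)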

@@ -612,7 +613,10 @@ def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
             with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
                 for i in range(steps_per_epoch):
                     batch = next(train_ds)
+                    if self.distributed_training and device_count > 1:
+                        batch = jax.tree.map(lambda x: x.reshape(
+                            (device_count, -1, *x.shape[1:])), batch)
 
                     train_state, loss, rng_state = train_step(train_state, batch, rng_state, local_device_indexes)
 
                     if self.distributed_training:
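
The added reshape is what gives pmap its leading device axis: a global batch of shape (B, ...) becomes (device_count, B // device_count, ...). A standalone sketch of the same transform, with hypothetical shapes and assuming B divides evenly by the device count:

import jax
import jax.numpy as jnp

device_count = jax.local_device_count()
batch = {'image': jnp.zeros((32, 64, 64, 3))}  # hypothetical global batch

# Same transform as the added lines: split the batch axis across devices.
sharded = jax.tree.map(
    lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch)
# With 8 devices: (32, 64, 64, 3) -> (8, 4, 64, 64, 3)
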
@@ -757,7 +761,7 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
 
         distributed_training = self.distributed_training
 
-        @jax.jit
+        # @jax.jit
         def train_step(train_state: TrainState, batch, rng_state: RandomMarkovState, local_device_index):
             """Train for a single step."""
             images = batch['image']
@@ -807,7 +811,10 @@ def model_loss(params):
             return train_state, loss, rng_state
 
         if distributed_training:
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
+            train_step = jax.pmap(axis_name="data")(train_step)
+            # train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
+        else:
+            train_step = jax.jit(train_step)
 
         return train_step
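
For contrast, a minimal standalone sketch of the shard_map path this commit disables (function and shapes hypothetical, not the repo's train_step). One semantic difference worth noting: with out_specs=P(), shard_map requires every output to be replicated across the mesh, so per-shard values need an explicit collective such as jax.lax.pmean before being returned.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

def mean_loss(batch):
    # batch here is the per-device shard, not the global array.
    local = jnp.mean(batch)
    # Make the output identical on all devices, as out_specs=P() demands.
    return jax.lax.pmean(local, axis_name='data')

sharded_loss = shard_map(mean_loss, mesh=mesh, in_specs=P('data'), out_specs=P())
loss = sharded_loss(jnp.arange(8.0))  # axis size must divide by the device count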
