Skip to content

Commit

Permalink
feat: just added a sample config
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 5, 2024
1 parent 295b198 commit b0691de
Showing 1 changed file with 132 additions and 85 deletions.
217 changes: 132 additions & 85 deletions training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Tuple, Mapping, Callable, List, Dict
from functools import partial
from flax.metrics import tensorboard
import flax.jax_utils
import jax.experimental.multihost_utils
import orbax
import orbax.checkpoint
import flax.jax_utils
Expand Down Expand Up @@ -43,6 +44,8 @@
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
from termcolor import colored

#####################################################################################################################
################################################# Initialization ####################################################
Expand Down Expand Up @@ -75,6 +78,17 @@ def get_random_key(self):
rng, subkey = jax.random.split(self.rng)
return RandomMarkovState(rng), subkey

# Colour name handed to termcolor's `colored` for each JAX process index.
# Only process indexes 0-7 have an entry.
PROCESS_COLOR_MAP = dict(enumerate((
    "green",
    "yellow",
    "magenta",
    "cyan",
    "white",
    "light_blue",
    "light_red",
    "light_cyan",
)))

#####################################################################################################################
################################################## Data Pipeline ####################################################
#####################################################################################################################
Expand Down Expand Up @@ -244,6 +258,14 @@ def map(self, element) -> Dict[str, jnp.array]:
"source": data_source_gcs(),
"augmenter": gcs_augmenters,
},
"laiona_coco": {
"source": data_source_gcs(),
"augmenter": gcs_augmenters,
},
"aesthetic_coyo": {
"source": data_source_gcs(),
"augmenter": gcs_augmenters,
},
}


Expand Down Expand Up @@ -316,6 +338,30 @@ def get_trainset():
############################################### Training Pipeline ###################################################
#####################################################################################################################

def _build_global_shape_and_sharding(
local_shape: tuple[int, ...], global_mesh: Mesh
) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
return global_shape, sharding


def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
    """Assemble a global ``jax.Array`` from this process's host-local data.

    ``array`` is split evenly along axis 0 into one chunk per local device of
    ``global_mesh``, each chunk is placed on its device, and the resulting
    single-device buffers are combined into one array with the global shape
    and sharding.

    Args:
        path: Pytree key path of the leaf. Unused here; accepted so the
            function can be called by ``jax.tree_util.tree_map_with_path``.
        array: Host-local data held by the current process.
        global_mesh: Mesh describing the devices the global array spans.

    Returns:
        The global array built from the per-device buffers.

    Raises:
        ValueError: If axis 0 of ``array`` cannot be split evenly across the
            local devices of ``global_mesh``.
    """
    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
    local_devices = global_mesh.local_devices

    try:
        # np.split raises ValueError when axis 0 is not evenly divisible.
        per_device_chunks = np.split(array, len(local_devices), axis=0)
    except ValueError as array_split_error:
        raise ValueError(
            f"Unable to put to devices shape {array.shape} with "
            f"local device count {len(local_devices)} "
        ) from array_split_error

    # One chunk per local device, paired positionally with the device list.
    device_buffers = jax.device_put(per_device_chunks, local_devices)
    return jax.make_array_from_single_device_arrays(global_shape, sharding, device_buffers)

def convert_to_global_tree(global_mesh, pytree):
    """Map every leaf of ``pytree`` through ``form_global_array``.

    Args:
        global_mesh: Mesh passed to ``form_global_array`` as ``global_mesh``.
        pytree: Pytree whose leaves are converted.

    Returns:
        A pytree with the same structure whose leaves are the results of
        ``form_global_array`` for each ``(path, leaf)`` pair.
    """
    leaf_to_global = partial(form_global_array, global_mesh=global_mesh)
    return jax.tree_util.tree_map_with_path(leaf_to_global, pytree)

@struct.dataclass
class Metrics(metrics.Collection):
accuracy: metrics.Accuracy
Expand Down Expand Up @@ -350,6 +396,7 @@ def __init__(self,
# Auto-detect if we are running on multiple devices
distributed_training = jax.device_count() > 1
self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
# self.sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('data'))

self.distributed_training = distributed_training
self.model = model
Expand All @@ -371,11 +418,13 @@ def __init__(self,
self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
self.wandb.define_metric("train/best_loss", step_metric="train/epoch")

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# checkpointer = orbax.checkpoint.PyTreeCheckpointer()
async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)

options = orbax.checkpoint.CheckpointManagerOptions(
max_to_keep=4, create=True)
self.checkpointer = orbax.checkpoint.CheckpointManager(
self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

if load_from_checkpoint:
latest_epoch, old_state, old_best_state, rngstate = self.load()
Expand Down Expand Up @@ -442,36 +491,20 @@ def init_state(
):
self.best_loss = 1e9

if self.distributed_training:
devices = jax.local_devices()
if len(devices) > 1:
print("Replicating state across devices ", devices)
state = flax.jax_utils.replicate(state, devices)
best_state = flax.jax_utils.replicate(best_state, devices)
self.rngstate = flax.jax_utils.replicate(self.rngstate, devices)
else:
print("Not replicating any state, Only single device connected to the process")

self.state = state
self.best_state = best_state

def get_state(self):
if self.distributed_training and jax.process_index() == 0:
return flax.jax_utils.unreplicate(self.state)
else:
return self.state
# return fully_replicated_host_local_array_to_global_array()
return jax.tree_util.tree_map(lambda x : np.array(x), self.state)

def get_best_state(self):
if self.distributed_training and jax.process_index() == 0:
return flax.jax_utils.unreplicate(self.best_state)
else:
return self.best_state
# return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)

def get_rngstate(self):
if self.distributed_training and jax.process_index() == 0:
return flax.jax_utils.unreplicate(self.rngstate)
else:
return self.rngstate
# return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)

def checkpoint_path(self):
experiment_name = self.name
Expand Down Expand Up @@ -507,12 +540,13 @@ def save(self, epoch=0):
'rngs': self.get_rngstate(),
'state': self.get_state(),
'best_state': self.get_best_state(),
'best_loss': self.best_loss
'best_loss': np.array(self.best_loss),
}
try:
save_args = orbax_utils.save_args_from_target(ckpt)
self.checkpointer.save(epoch, ckpt, save_kwargs={
'save_args': save_args}, force=True)
self.checkpointer.wait_until_finished()
pass
except Exception as e:
print("Error saving checkpoint", e)
Expand All @@ -522,7 +556,7 @@ def _define_train_step(self, **kwargs):
loss_fn = self.loss_fn
distributed_training = self.distributed_training

def train_step(train_state: SimpleTrainState, batch, rng_state: RandomMarkovState, local_device_indexes):
def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
"""Train for a single step."""
images = batch['image']
labels = batch['label']
Expand All @@ -540,11 +574,8 @@ def model_loss(params):
return train_state, loss, rng_state

if distributed_training:
train_step = jax.pmap(train_step, axis_name="data")
# train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
else:
train_step = jax.jit(train_step)

train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
train_step = jax.pmap(train_step)
return train_step

def _define_compute_metrics(self):
Expand Down Expand Up @@ -577,6 +608,7 @@ def config(self):
}

def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
from flax.metrics import tensorboard
summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
summary_writer.hparams({
**self.config(),
Expand All @@ -596,45 +628,57 @@ def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
compute_metrics = self._define_compute_metrics()
train_state = self.state
rng_state = self.rngstate
device_count = jax.local_device_count()
# train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
global_device_count = jax.device_count()
local_device_count = jax.local_device_count()
process_index = jax.process_index()
if self.distributed_training:
local_device_indexes = jnp.arange(device_count)
global_device_indexes = jnp.arange(global_device_count)
else:
local_device_indexes = 0
global_device_indexes = 0

while self.latest_epoch < epochs:
self.latest_epoch += 1
current_epoch = self.latest_epoch
print(f"\nEpoch {current_epoch}/{epochs}")
start_time = time.time()
def train_loop(current_epoch, pbar: tqdm.tqdm, train_state, rng_state):
epoch_loss = 0

with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
for i in range(steps_per_epoch):
batch = next(train_ds)
if self.distributed_training and device_count > 1:
batch = jax.tree.map(lambda x: x.reshape(
(device_count, -1, *x.shape[1:])), batch)
train_state, loss, rng_state = train_step(train_state, batch, rng_state, local_device_indexes)

if self.distributed_training:
loss = jax.experimental.multihost_utils.process_allgather(loss)
loss = jnp.mean(loss)
current_step = 0
for i in range(steps_per_epoch):
batch = next(train_ds)
if self.distributed_training and global_device_count > 1:
# Convert the local device batches to a unified global jax.Array
batch = convert_to_global_tree(self.mesh, batch)
train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)

if self.distributed_training:
loss = jax.experimental.multihost_utils.process_allgather(loss)
loss = jnp.mean(loss) # Just to make sure its a scaler value
epoch_loss += loss

epoch_loss += loss
if i % 100 == 0:
if pbar is not None:
if i % 10 == 0:
pbar.set_postfix(loss=f'{loss:.4f}')
pbar.update(100)
pbar.update(10)
current_step = current_epoch*steps_per_epoch + i
if self.wandb is not None:
self.wandb.log({
"train/step" : current_step,
"train/loss": loss,
}, step=current_step)
print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
return epoch_loss, current_step, train_state, rng_state

print(f"\n\tEpoch done")
while self.latest_epoch < epochs:
self.latest_epoch += 1
current_epoch = self.latest_epoch
print(f"\nEpoch {current_epoch}/{epochs}")
start_time = time.time()
epoch_loss = 0

if process_index == 0:
with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, pbar, train_state, rng_state)
else:
epoch_loss, current_step, train_state, rng_state = train_loop(current_epoch, None, train_state, rng_state)
print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))

end_time = time.time()
self.state = train_state
self.rngstate = rng_state
Expand All @@ -645,21 +689,17 @@ def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
self.best_loss = avg_loss
self.best_state = train_state
self.save(current_epoch)
if self.wandb is not None:
self.wandb.log({
"train/epoch_time": total_time,
"train/avg_time_per_step": avg_time_per_step,
"train/avg_loss": avg_loss,
"train/best_loss": self.best_loss,
"train/epoch": current_epoch,
}, step=current_step)

# Compute Metrics
metrics_str = ''

print(
f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")


if process_index == 0:
if self.wandb is not None:
self.wandb.log({
"train/epoch_time": total_time,
"train/avg_time_per_step": avg_time_per_step,
"train/avg_loss": avg_loss,
"train/best_loss": self.best_loss,
"train/epoch": current_epoch,
}, step=current_step)
print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
self.save(epochs)
return self.state

Expand Down Expand Up @@ -746,7 +786,7 @@ def generate_states(
return state, best_state

def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
noise_schedule = self.noise_schedule
noise_schedule: NoiseScheduler = self.noise_schedule
model = self.model
model_output_transform = self.model_output_transform
loss_fn = self.loss_fn
Expand All @@ -762,8 +802,12 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
distributed_training = self.distributed_training

# @jax.jit
def train_step(train_state: TrainState, batch, rng_state: RandomMarkovState, local_device_index):
def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
"""Train for a single step."""
rng_state, subkey = rng_state.get_random_key()
subkey = jax.random.fold_in(subkey, local_device_index.reshape())
local_rng_state = RandomMarkovState(subkey)

images = batch['image']
# normalize image
images = (images - 127.5) / 127.5
Expand All @@ -777,11 +821,6 @@ def train_step(train_state: TrainState, batch, rng_state: RandomMarkovState, loc
label_seq = jnp.concat(
[null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

rng_state, subkey = rng_state.get_random_key()
subkey = jax.random.fold_in(subkey, local_device_index)
subkey = jax.random.fold_in(subkey, jax.process_index())
local_rng_state = RandomMarkovState(subkey)

noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

local_rng_state, rngs = local_rng_state.get_random_key()
Expand All @@ -806,14 +845,14 @@ def model_loss(params):
loss, grads = jax.value_and_grad(model_loss)(train_state.params)
if distributed_training:
grads = jax.lax.pmean(grads, "data")
loss = jax.lax.pmean(loss, "data")
train_state = train_state.apply_gradients(grads=grads)
train_state = train_state.apply_ema(self.ema_decay)
return train_state, loss, rng_state

if distributed_training:
train_step = jax.pmap(train_step, axis_name="data")
# train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
else:
train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
out_specs=(P(), P(), P()))
train_step = jax.jit(train_step)

return train_step
Expand Down Expand Up @@ -1048,7 +1087,15 @@ def main(args):
print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} samples")

final_state = trainer.fit(data, batches, epochs=CONFIG['epochs'])

if __name__ == "__main__":
    # Script entry point: parse the command-line flags and launch training.
    args = parser.parse_args()
    main(args)

"""
JAX_TRACEBACK_FILTERING=off python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/arrayrecord/laion-aesthetics-12m+mscoco-2017'\
--epochs=40 --batch_size=128 \
--learning_rate=2.7e-4 --num_res_blocks=3 \
--use_self_and_cross=False --dtype=float32 --precision=high --attention_heads=16 \
--experiment_name='batch 128 multi-host laiona_coco'
"""

0 comments on commit b0691de

Please sign in to comment.