From 3ffa9b974e30b0f1ef56f88ae5a9db213fe6fdc9 Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 16 Apr 2024 08:17:14 -0700
Subject: [PATCH] Fix multigpu training for train.py script

---
 train.py | 54 +++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/train.py b/train.py
index 1960d0faf..1c0c9261d 100644
--- a/train.py
+++ b/train.py
@@ -44,8 +44,8 @@ def parse_args():
 
     # Data args
     training_group.add_argument('--dataset', default='shakespeare_char', type=str)
-    training_group.add_argument('--gradient_accumulation_steps', default=1, type=int)
-    training_group.add_argument('--batch_size', default=64, type=int)
+    training_group.add_argument('--gradient_accumulation_steps', default=40, type=int)
+    training_group.add_argument('--batch_size', default=12, type=int)
     training_group.add_argument("--seed", default=1337, type=int)
 
     # Model args
@@ -207,9 +207,10 @@ def parse_args():
 
 
 class Trainer:
-    def __init__(self, args, model_group):
+    def __init__(self, args, model_group, training_group):
         self.args = args
         self.model_group = model_group
+        self.training_group = training_group
         self.setup()
 
     def setup(self):
@@ -217,22 +218,21 @@ def setup(self):
         self.ddp = int(os.environ.get('RANK', -1)) != -1
         if self.ddp:
             init_process_group(backend=self.args.backend)
-            print(self.args)
             self.ddp_rank = int(os.environ['RANK'])
             self.ddp_local_rank = int(os.environ['LOCAL_RANK'])
             self.ddp_world_size = int(os.environ['WORLD_SIZE'])
             self.device = f'cuda:{self.ddp_local_rank}'
-            print("this is my device", self.device)
             torch.cuda.set_device(self.device)
-            self.master_process = self.ddp_rank == 0
+            self.master_process = (self.ddp_rank == 0)
             self.seed_offset = self.ddp_rank
-            self.gradient_accumulation_steps //= self.ddp_world_size
+            self.args.gradient_accumulation_steps //= self.ddp_world_size
         else:
             self.device = self.args.device
             self.master_process = True
             self.seed_offset = 0
             self.ddp_world_size = 1
 
+        self.tokens_per_iter = self.args.gradient_accumulation_steps * self.ddp_world_size * self.args.batch_size * self.args.block_size
 
 
         if self.master_process:
@@ -492,17 +492,44 @@ def train(self):
 
             if self.iter_num == 0 and self.args.eval_only:
                 break
+            loss = None
 
             for micro_step in range(self.args.gradient_accumulation_steps):
                 if self.ddp:
                     self.model.require_backward_grad_sync = (micro_step == self.args.gradient_accumulation_steps - 1)
-                
+
                 with self.ctx:
                     logits, loss = self.model(self.X, self.Y)
                     loss = loss / self.args.gradient_accumulation_steps
-
-                self.X, self.Y = self.get_batch('train')
-
+                self.scaler.scale(loss).backward()
+
+                if micro_step == self.args.gradient_accumulation_steps - 1:
+                    if self.args.grad_clip != 0.0:
+                        self.scaler.unscale_(self.optimizer)
+                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
+
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                self.X, self.Y = self.get_batch('train')
+
+            if loss is not None:  # Check if loss has a valid value
+                lossf = loss.item() * self.args.gradient_accumulation_steps
+                lossf = loss.item() * self.args.gradient_accumulation_steps
+                t1 = time.time()
+                dt = t1 - t0
+                t0 = t1
+                if self.iter_num % self.args.log_interval == 0 and self.master_process:
+                    if local_iter_num >= 5:
+                        mfu = self.raw_model.estimate_mfu(self.args.batch_size * self.args.gradient_accumulation_steps, dt)
+                        running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+                    print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
+                    if math.isnan(lossf):
+                        sys.exit("Exiting training loss is NaN")
+                    self.log_metrics_non_validation(lossf, running_mfu, self.iter_num)
+            else:
+                print(f"Warning: loss is None at iteration {self.iter_num}")
 
             if self.args.grad_clip != 0.0:
                 self.scaler.unscale_(self.optimizer)
@@ -553,8 +580,9 @@ def train(self):
             wandb.finish()
 
 def main():
-    args, model_group, _, _ = parse_args()
-    trainer = Trainer(args, model_group)
+    args, model_group, training_group, _ = parse_args()
+    trainer = Trainer(args, model_group, training_group)
+    trainer.train()
 
     if trainer.ddp:
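
Reviewer note: below is a minimal, standalone sketch of the pattern this patch restores, namely gradient accumulation under DistributedDataParallel where gradients are synchronized only on the last micro-step of each accumulation window, the accumulation steps are divided across ranks, and the optimizer is stepped through a GradScaler with optional gradient clipping. It is an illustration under assumptions, not the repository's Trainer: model, optimizer, get_batch, and the default hyperparameter values are hypothetical placeholders (the model is assumed to return a (logits, loss) pair, as in this repo), and it assumes the script is launched with torchrun so that RANK, LOCAL_RANK, and WORLD_SIZE are set.

# Minimal sketch of DDP training with gradient accumulation (placeholder names, not the repo's Trainer).
import os
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

def train_sketch(model, optimizer, get_batch, grad_accum_steps=40, grad_clip=1.0, max_iters=100):
    ddp = int(os.environ.get('RANK', -1)) != -1
    if ddp:
        init_process_group(backend='nccl')
        local_rank = int(os.environ['LOCAL_RANK'])
        device = f'cuda:{local_rank}'
        torch.cuda.set_device(device)
        grad_accum_steps //= int(os.environ['WORLD_SIZE'])  # each rank handles a share of the micro-steps
        model = DDP(model.to(device), device_ids=[local_rank])
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)

    use_amp = device != 'cpu'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    X, Y = get_batch('train')
    for _ in range(max_iters):
        for micro_step in range(grad_accum_steps):
            if ddp:
                # Sync gradients only on the last micro-step; earlier backwards skip the all-reduce.
                model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
            with torch.autocast(device_type='cuda', enabled=use_amp):
                _, loss = model(X, Y)           # model is assumed to return (logits, loss)
                loss = loss / grad_accum_steps  # scale so accumulated gradients average correctly
            scaler.scale(loss).backward()
            if micro_step == grad_accum_steps - 1:
                if grad_clip != 0.0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            X, Y = get_batch('train')           # prefetch the next micro-batch
    if ddp:
        destroy_process_group()

Skipping gradient sync on all but the last micro-step avoids one all-reduce per micro-batch, which is why the require_backward_grad_sync toggle in the patched train() matters for multi-GPU throughput.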