Skip to content

Commit

Permalink
update wandb calls
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielilharco committed Sep 28, 2023
1 parent a3e8211 commit d1a0f8d
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions src/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
}
log_data.update({name:val.val for name,val in losses_m.items()})

for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, step)
if args.wandb:
assert wandb is not None, 'Please install wandb.'
wandb.log({name: val, 'step': step}, step=step)
log_data = {"train/" + name: val for name, val in log_data.items()}

if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, step)

if args.wandb:
assert wandb is not None, 'Please install wandb.'
log_data['step'] = step # for backwards compatibility
wandb.log(log_data, step=step)

# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
Expand Down Expand Up @@ -317,19 +320,27 @@ def evaluate(model, data, epoch, args, tb_writer=None):
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)

log_data = {"val/" + name: val for name, val in metrics.items()}

if args.save_logs:
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, epoch)

with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")

if args.wandb:
assert wandb is not None, 'Please install wandb.'
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})
if 'train' in data:
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
step = num_batches_per_epoch * epoch
else:
step = None
log_data['epoch'] = epoch
wandb.log(log_data, step=step)

return metrics

Expand Down

0 comments on commit d1a0f8d

Please sign in to comment.