
remove mid epoch save/load, fix between epoch save/load
StephAO committed Feb 16, 2020
1 parent 0960934 commit f59e83c
Showing 3 changed files with 22 additions and 51 deletions.
5 changes: 2 additions & 3 deletions data_utils/datasets.py
@@ -500,11 +500,10 @@ def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None
     def __len__(self):
         return self.ds_len
 
-    def set_args(self, modes, past_iters):
+    def set_args(self, modes):
         # TODO: full training defined by number of tokens seen - not by number of iterations
         print("setting up args, modes:", modes)
         self.modes = modes
-        self.past_iters = past_iters
         self.split_percent = 1.0
         self.corruption_rate = 0.
         self.num_sent_per_seq = 1
@@ -545,7 +544,7 @@ def __getitem__(self, idx):
         # get rng state corresponding to index (allows deterministic random pair)
         if idx >= self.ds_len:
             raise StopIteration
-        rng = random.Random(idx) #idx + self.past_iters)
+        rng = random.Random(idx)
         self.idx = idx
         # get sentence pair and label
         sentence_labels = None
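The __getitem__ change above is what makes between-epoch resume safe: the pair-sampling RNG is now keyed on the raw index alone, so an example no longer depends on how many iterations preceded it. A minimal sketch of the pattern (a toy dataset for illustration, not the repo's class):

    import random

    class PairDataset:
        """Deterministic 'random' pairing per index (illustrative sketch)."""

        def __init__(self, sentences):
            self.sentences = sentences
            self.ds_len = len(sentences)

        def __len__(self):
            return self.ds_len

        def __getitem__(self, idx):
            if idx >= self.ds_len:
                raise StopIteration
            # A fresh Random seeded on idx alone: the same index always picks
            # the same partner, regardless of prior epochs or resumed runs.
            rng = random.Random(idx)
            partner = rng.randrange(self.ds_len)
            return self.sentences[idx], self.sentences[partner]

    ds = PairDataset(["a", "b", "c", "d"])
    assert ds[2] == ds[2]  # stable across calls and across restarts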
59 changes: 18 additions & 41 deletions pretrain_bert.py
@@ -223,12 +223,10 @@ def train_step(input_data, model, criterion, optimizer, lr_scheduler, modes, args):
     losses_reduced, num_tokens = backward_step(optimizer, model, losses, num_tokens, args)
     # Update parameters.
     optimizer.step()
-    # Update learning rate.
-    skipped_iter = 0
-    return losses_reduced, skipped_iter, num_tokens
+    return losses_reduced, num_tokens
 
 
-def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, past_iters, args):
+def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
     """Train one full epoch."""
     print("Starting training of epoch {}".format(epoch), flush=True)
     # Turn on training mode which enables dropout.
@@ -243,7 +241,6 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
     tot_tokens = 0
     iteration = 0
     tot_iteration = 0
-    skipped_iters = 0
     if args.resume_dataloader:
         iteration = args.mid_epoch_iters
         args.resume_dataloader = False
@@ -252,7 +249,7 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
     modes = args.modes.split(',')
     if args.incremental:
         modes = modes[:epoch]
-    train_data.dataset.set_args(modes, past_iters)
+    train_data.dataset.set_args(modes)
     data_iters = iter(train_data)
 
     timers('interval time').start()
@@ -273,7 +270,7 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
         modes_ = modes
         while True:
             try:
-                losses, skipped_iter, num_tokens = train_step(next(data_iters),
+                losses, num_tokens = train_step(next(data_iters),
                                                 model,
                                                 criterion,
                                                 optimizer,
@@ -288,8 +285,8 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
 
             log_tokens += num_tokens.item()
             tot_tokens += num_tokens.item()
+            # Update learning rate.
             lr_scheduler.step(step_num=(epoch-1) * max_tokens + tot_tokens)
-            skipped_iters += skipped_iter
             iteration += 1
             # Update losses.
             for mode, loss in losses.items():
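Indexing the scheduler by an absolute token position, step_num = (epoch-1) * max_tokens + tot_tokens, means a run resumed at an epoch boundary lands on exactly the same learning rate without replaying iterations. A hedged sketch of a schedule driven that way (the class and its step(step_num=...) signature are assumptions modeled on the diff, not the repo's actual scheduler):

    class TokenIndexedSchedule:
        """Toy linear warmup/decay LR schedule keyed on tokens seen (sketch)."""

        def __init__(self, optimizer, base_lr, warmup_tokens, total_tokens):
            self.optimizer = optimizer
            self.base_lr = base_lr
            self.warmup_tokens = warmup_tokens
            self.total_tokens = total_tokens

        def step(self, step_num):
            # step_num is an absolute token count, so the LR is a pure
            # function of progress - no per-iteration state to checkpoint.
            if step_num < self.warmup_tokens:
                lr = self.base_lr * step_num / self.warmup_tokens
            else:
                frac = (step_num - self.warmup_tokens) / max(1, self.total_tokens - self.warmup_tokens)
                lr = self.base_lr * max(0.0, 1.0 - frac)
            for group in self.optimizer.param_groups:
                group['lr'] = lr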
@@ -329,14 +326,14 @@ def train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args):
                 iteration = 0
 
             # Checkpointing
-            if args.save and args.save_iters and iteration % args.save_iters == 0:
-                total_iters = args.train_iters * (epoch-1) + iteration
-                model_suffix = 'model/%d.pt' % (total_iters)
-                save_checkpoint(model_suffix, epoch, iteration, model, optimizer,
-                                lr_scheduler, args)
+            # Currently unsupported, fix saving mid epoch tokens to fix
+            # if args.save and args.save_iters and iteration % args.save_iters == 0:
+            #     total_iters = args.train_iters * (epoch-1) + iteration
+            #     model_suffix = 'model/%d.pt' % (total_iters)
+            #     save_checkpoint(model_suffix, epoch, iteration, model, optimizer,
+            #                     lr_scheduler, args)
 
     print("Learnt using {} tokens over {} iterations this epoch".format(tot_tokens, tot_iteration + iteration))
-    return tot_iteration, skipped_iters
 
 def evaluate(epoch, data_source, model, criterion, elapsed_time, args, test=False):
     """Evaluation."""
@@ -351,7 +348,6 @@ def evaluate(epoch, data_source, model, criterion, elapsed_time, args, test=False):
     modes = args.modes.split(',')
     data_source.dataset.set_args(modes, 0)
     data_iters = iter(data_source)
-    start_time = time.time()
     with torch.no_grad():
         iteration = 0
         while tokens < max_tokens:
Expand Down Expand Up @@ -383,7 +379,6 @@ def evaluate(epoch, data_source, model, criterion, elapsed_time, args, test=Fals
total_losses[mode] = total_losses.get(mode, 0.0) + loss.data.detach().float().item()
iteration += 1
tokens += num_tokens.item()
#print("Done iteration", iteration, time.time() - start_time)

print("Evaluated using {} tokens over {} iterations.".format(tokens, iteration), flush=True)

@@ -475,37 +470,27 @@ def main():
 
     #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
     timers("total time").start()
-    epoch = 0
     # At any point you can hit Ctrl + C to break out of training early.
     try:
-        total_iters = 0
-        skipped_iters = 0
         start_epoch = 1
         best_val_loss = float('inf')
         # Resume data loader if necessary.
        if args.resume_dataloader:
             start_epoch = args.epoch
-            total_iters = args.total_iters
-            train_data.batch_sampler.start_iter = total_iters % len(train_data)
-        # For all epochs.
+        # For all epochs.
         for epoch in range(start_epoch, args.epochs+1):
             if args.shuffle:
                 train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
             timers('epoch time').start()
-            iteration, skipped = train_epoch(epoch, model, optimizer,
-                                             train_data, lr_scheduler,
-                                             criterion, timers, experiment,
-                                             metrics, total_iters, args)
+            train_epoch(epoch, model, optimizer, train_data, lr_scheduler, criterion, timers, experiment, metrics, args)
 
             elapsed_time = timers('epoch time').elapsed()
-            total_iters += iteration
-            skipped_iters += skipped
 
 
             if args.save:
                 ck_path = 'ck/model_{}.pt'.format(epoch)
                 print('saving ck model to:',
                       os.path.join(args.save, ck_path))
-                save_checkpoint(ck_path, epoch+1, total_iters, model,
-                                optimizer, lr_scheduler, args)
+                save_checkpoint(ck_path, epoch+1, model, optimizer, lr_scheduler, args)
 
             val_loss = evaluate(epoch, val_data, model, criterion, elapsed_time, args)
 
@@ -515,26 +500,18 @@ def main():
                 best_path = 'best/model.pt'
                 print('saving best model to:',
                       os.path.join(args.save, best_path))
-                save_checkpoint(best_path, epoch+1, total_iters, model,
-                                optimizer, lr_scheduler, args)
+                save_checkpoint(best_path, epoch+1, model, optimizer, lr_scheduler, args)
 
 
     except KeyboardInterrupt:
         print('-' * 100)
         print('Exiting from training early')
-        if args.save and False: # WARNING I disabled this to save memory, but may be necessary in the future
-            cur_path = 'current/model.pt'
-            print('saving current model to:',
-                  os.path.join(args.save, cur_path))
-            save_checkpoint(cur_path, epoch, total_iters, model, optimizer,
-                            lr_scheduler, args)
         exit()
 
     if args.save and False: # WARNING I disabled this to save memory, but may be necessary in the future
         final_path = 'final/model.pt'
         print('saving final model to:', os.path.join(args.save, final_path))
-        save_checkpoint(final_path, args.epochs, total_iters, model, optimizer,
-                        lr_scheduler, args)
+        save_checkpoint(final_path, args.epochs, model, optimizer, lr_scheduler, args)
 
     if test_data is not None:
         # Run on test data.
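With mid-epoch state gone, resuming reduces to "start at this epoch": load_checkpoint yields an epoch number, the loop runs from there, and each save records epoch+1 as the restart point. A condensed, self-contained sketch of that contract (the callables stand in for the repo's train_epoch and save_checkpoint):

    def run_epochs(train_one_epoch, save_ckpt, start_epoch, last_epoch):
        """Epoch-granularity loop: resume state is just an epoch number (sketch)."""
        for epoch in range(start_epoch, last_epoch + 1):
            train_one_epoch(epoch)
            # Record epoch + 1 so a reload begins at the *next* epoch,
            # mirroring save_checkpoint(ck_path, epoch+1, ...) above.
            save_ckpt(epoch + 1)

    # e.g. resuming a 10-epoch run that previously finished epoch 3:
    run_epochs(lambda e: print("training epoch", e),
               lambda e: print("checkpoint restart epoch", e),
               start_epoch=4, last_epoch=10)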
9 changes: 2 additions & 7 deletions utils.py
@@ -109,9 +109,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
     checkpoint_path = args.load
     model_path = checkpoint_path
     model_sd = torch.load(model_path, map_location='cpu')
-    total_iters = model_sd['total_iters']
     epoch = model_sd['epoch']
-    i = model_sd['mid_epoch_iters']
     model.load_state_dict(model_sd['sd'])
 
     checkpoint_path = os.path.dirname(checkpoint_path)
@@ -134,10 +132,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
         np.random.set_state(rng_state[2])
         random.setstate(rng_state[3])
 
-    return epoch, i, total_iters
+    return epoch
 
 
-def save_checkpoint(model_suffix, epoch, i, model, optimizer, lr_scheduler, args):
+def save_checkpoint(model_suffix, epoch, model, optimizer, lr_scheduler, args):
     """Save a model checkpoint."""
 
     model_path = os.path.join(args.save, model_suffix)
@@ -150,11 +148,8 @@ def save_checkpoint(model_suffix, epoch, model, optimizer, lr_scheduler, args):
                         torch.distributed.get_rank() > 0):
         if not os.path.exists(checkpoint_dir):
             os.makedirs(checkpoint_dir)
-        total_iters = args.train_iters * (epoch-1) + i
         sd = {'sd': model.state_dict()}
-        sd['total_iters'] = total_iters
         sd['epoch'] = epoch
-        sd['mid_epoch_iters'] = i
         torch.save(sd, model_path)
         print('saved', model_path)
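After this commit a checkpoint carries only the model state dict and the epoch to resume at; total_iters and mid_epoch_iters are gone. A round-trip sketch of the slimmed-down format (the file path and Linear model are placeholders):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)

    # Save: matches the reduced dict built in save_checkpoint above.
    sd = {'sd': model.state_dict(), 'epoch': 3}
    torch.save(sd, '/tmp/model_example.pt')

    # Load: restore weights and recover only the epoch to restart from.
    loaded = torch.load('/tmp/model_example.pt', map_location='cpu')
    model.load_state_dict(loaded['sd'])
    epoch = loaded['epoch']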
