Skip to content

Commit

Permalink
compiled model state_dict() workaround
Browse files Browse the repository at this point in the history
PyTorch 2.0 adds an '_orig_mod.' prefix to the keys of state_dict() on compiled models.
For compatibility, we save the state_dict() of the original model, which shares the
weights but has keys without the prefix.
  • Loading branch information
EIFY authored and rwightman committed Sep 22, 2023
1 parent 69073c0 commit 905fc54
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/training/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,13 @@ def main(args):
wandb.save(params_file)
logging.debug('Finished loading wandb.')

# Pytorch 2.0 adds '_orig_mod.' prefix to keys of state_dict() of compiled models.
# For compatibility, we save state_dict() of the original model, which shares the
# weights without the prefix.
original_model = model
if args.torchcompile:
logging.info('Compiling model...')
model = torch.compile(model)
model = torch.compile(original_model)

if 'train' not in data:
# If using int8, convert to inference mode.
Expand Down Expand Up @@ -426,7 +430,7 @@ def main(args):
checkpoint_dict = {
"epoch": completed_epoch,
"name": args.name,
"state_dict": model.state_dict(),
"state_dict": original_model.state_dict(),
"optimizer": optimizer.state_dict(),
}
if scaler is not None:
Expand Down

0 comments on commit 905fc54

Please sign in to comment.