
[RFC] Step-based checkpointing in torchtune #2105

Conversation

@joecummings (Contributor) commented Dec 3, 2024

Enabling step-based checkpointing in torchtune

Original context: #2070

What are we currently doing?

We currently only checkpoint at epoch boundaries. That means a fine-tuning run has to iterate through all data in a dataset before saving a checkpoint. That's a problem when GPUs (especially interconnected GPUs) can fail frequently, losses can diverge, and datasets keep getting larger and larger.

We provide a tiny amount of flexibility by allowing the user to specify max_steps_per_epoch, so they can short-circuit the epoch and save sooner. In addition, it's always possible to split a dataset into chunks and train over them independently, resuming from the previous checkpoint to simulate a longer training run.

Neither of these "hacks" is ideal, and users have repeatedly asked if they can control checkpointing based on the number of training steps (#988, #1107).

What does step-based checkpointing look like for the user?

I think the best way to explain this is with an example. Let's take our Llama3 8B single device full fine-tuning recipe, which uses the Alpaca dataset. The Alpaca dataset has ~52k samples. Using a batch size of 2 and gradient accumulation over 16 batches, we can estimate around 1625 optimizer steps in this training run. Let's save a checkpoint every 500 steps!
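
For reference, here's a quick back-of-the-envelope of where that ~1625 number comes from (the 52k sample count is approximate):

num_samples = 52_000                # approximate size of the Alpaca dataset
batch_size = 2
gradient_accumulation_steps = 16

# One optimizer step consumes batch_size * gradient_accumulation_steps samples
samples_per_step = batch_size * gradient_accumulation_steps    # 32
total_steps = num_samples // samples_per_step                   # 1625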

From the config, we can specify:

save_every_n_train_steps: 500

And in our output directory, we can expect to see something like this:

output_dir/
	step_500/
		llama3_8b_single_device.yaml
		config.json
		model-0000-of-0003.safetensors
		...
	step_1000/
		...
	step_1500/
		...
	step_1625/
		...

You'll see that at the end of the training loop, a final checkpoint is saved regardless of how many steps have passed since the last checkpoint was saved.

At this point you might be saying: @joecummings, do you think disk space grows on trees? Do you think we all drive Bugattis and smash up Grace Hopper machines for fun? Each Llama3 8B checkpoint is roughly 16 GB on disk, and we've saved 4 copies of it in addition to the base model we started from. That's 80 GB of weights sitting around! Not even to mention the optimizer states, if we wanted to save those too...

Introducing:

keep_last_n_checkpoints: 1

This param prunes all checkpoints except for the last N, leaving you with just the ones you're interested in:

output_dir/
	step_1625/
		llama3_8b_single_device.yaml
		config.json
		model-0000-of-0003.safetensors
		...

What about the concept of epochs?

The concept of epochs will stay as a way to control how long training runs, as will the option to shorten training using max_steps_per_epoch; however, checkpointing will be specified entirely in terms of steps.

Will this slow down training?

Great question! Checkpointing can take a long time, especially if we're saving the optimizer state for resuming training later. For single-device recipes this likely isn't a huge issue, but for distributed recipes, where the state dict needs to be gathered on rank zero before saving, it can be very slow, so anything that increases the frequency of checkpointing will increase the total training time. There are two ways to mitigate this:

  1. Specify a longer period so you save checkpoints less frequently
  2. Look into using the DCP checkpointer for intermediate checkpoints, which drastically reduces the time it takes to save. See Faster intermediate checkpoints with DCP async save in TorchTune #2006 for more information.
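
As a rough illustration of option 2, a checkpointer could hand the save off to PyTorch Distributed Checkpoint's async API so training continues while files are written. This is just a sketch, not the implementation from #2006; it assumes a recent PyTorch release (2.3+) where torch.distributed.checkpoint.async_save is available, and model, optimizer, output_dir, and num_steps come from the surrounding recipe:

import torch.distributed.checkpoint as dcp

# Kick off the checkpoint write in the background and keep training
state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
ckpt_future = dcp.async_save(
    state_dict,
    checkpoint_id=f"{output_dir}/step_{num_steps}",
)

# ... continue training ...

# Wait for the previous async save to finish before starting the next one
ckpt_future.result()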

What changes need to be made in code?

In the recipe:

num_steps = 0
for curr_epoch in range(self.epochs_run, self.total_epochs):
    steps_run_since_last_saved_ckpt = 0
    for idx, batch in enumerate(self._dataloader):
        # Model forward, loss, and backward
        ...
        # A "step" is an optimizer step, i.e. one full gradient accumulation
        # window rather than a single batch
        if (idx + 1) % gradient_accumulation_steps == 0:
            steps_run_since_last_saved_ckpt += 1
            num_steps += 1

        if steps_run_since_last_saved_ckpt == save_every_n_train_steps:
            self.save_checkpoint(ckpt, step=num_steps)
            steps_run_since_last_saved_ckpt = 0

# One final save, regardless of how many steps have passed since the last checkpoint
self.save_checkpoint(ckpt, step=num_steps)

And in the checkpointer:

def save_checkpoint(self, ckpt, step):
    # Prune old checkpoints if needed
    checkpoints_already_saved = get_all_prev_checkpoint_paths(self.output_dir)
    if len(checkpoints_already_saved) >= self.keep_last_n_checkpoints:
        # Keep only the newest (keep_last_n_checkpoints - 1) old checkpoints so that,
        # together with the one we're about to write, keep_last_n_checkpoints remain
        prune_old_checkpoints(
            checkpoints_already_saved, keep=self.keep_last_n_checkpoints - 1
        )

    # Create a new folder for this step
    new_ckpt_path = self.output_dir / f"step_{step}"
    new_ckpt_path.mkdir(exist_ok=False)
    # Save the new checkpoint
    torch.save(ckpt, new_ckpt_path / "model.bin")
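
The two helpers referenced above (get_all_prev_checkpoint_paths and prune_old_checkpoints) aren't spelled out in this RFC; a minimal sketch of what they could look like, with signatures and implementations that are assumptions rather than existing torchtune code:

import shutil
from pathlib import Path

def get_all_prev_checkpoint_paths(output_dir: Path) -> list[Path]:
    # step_* directories in the output dir, sorted from oldest to newest step
    return sorted(
        (p for p in output_dir.glob("step_*") if p.is_dir()),
        key=lambda p: int(p.name.split("_")[-1]),
    )

def prune_old_checkpoints(checkpoint_paths: list[Path], keep: int) -> None:
    # Delete all but the newest `keep` checkpoint directories
    num_to_delete = max(len(checkpoint_paths) - keep, 0)
    for path in checkpoint_paths[:num_to_delete]:
        shutil.rmtree(path)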

Inspiration from relevant repositories:

@felipemello1 (Contributor)

It makes sense to me.

I am not 100% sure how I feel about keeping epochs AND steps, since defaulting to one might simplify things. Internally, we only have the concept of steps, and it is up to the user to use bsz and len(dataset) to calculate how many steps == 1 epoch (which is a bit annoying).

Can we have a config showing the args for a streaming dataset and checkpointing? If it looks simple, I like keeping both.

Regarding keep_last_n_checkpoints=1, my only question is whether there's a safer alternative than deleting files. Probably not...

@calvinpelletier (Contributor)

Yay step-based checkpointing! Some thoughts:

  1. I second Felipe's comment about dropping support for epoch-based checkpointing. Our code will be cleaner and simpler if our whole ecosystem of checkpointing/validating/logging/etc is only step-based. The users who prefer epochs can easily calculate the number of steps in an epoch.

  2. Regarding keeping the last n checkpoints, should we also allow keeping the best checkpoint (once we add support for validation)? When I run training jobs for large models, I don't want to save too many checkpoints and run out of disk space, but I also don't want training to collapse overnight only to find that all the good checkpoints have already been deleted. Since we don't have validation/evaluation during training yet, we could choose the best checkpoint based on training loss for now.

  3. There's an annoying trade-off when choosing the step interval for checkpointing/validating/etc. Large intervals are nice because they don't slow training down too much, but small intervals give you quick feedback that everything is working correctly, and the extra data points are important early on when things are changing quickly. Personally, I use a logarithmic step-based checkpointing/validating/etc. schedule, where the interval is small during early training and gets longer as training goes on. I'm not sure if other people like this too, but it's trivial to implement and supporting it wouldn't complicate our code very much.
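
For concreteness, a minimal sketch of the kind of logarithmic schedule described in (3), where the gap between checkpoints doubles until it hits a cap (the function name and parameters are illustrative, not an existing torchtune API):

def should_checkpoint(step: int, base_interval: int = 100, max_interval: int = 2000) -> bool:
    # Checkpoint steps are spaced base_interval, 2 * base_interval, 4 * base_interval, ...
    # apart, with the gap capped at max_interval once training is well underway
    interval = base_interval
    next_ckpt = interval
    while next_ckpt < step:
        interval = min(interval * 2, max_interval)
        next_ckpt += interval
    return step == next_ckpt

In the recipe sketch above, this predicate could replace the fixed steps_run_since_last_saved_ckpt == save_every_n_train_steps check.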

@joecummings (Contributor, Author)

I am not 100% sure how I feel about keeping epochs AND steps, since defaulting to one might simplify things. Internally, we only have the concept of steps, and it is up to the user to use bsz and len(dataset) to calculate how many steps == 1 epoch (which is a bit annoying).

I totally get this. I might, however, be annoying and punt this to a different discussion, because right now I'm keeping epochs as the way to determine how long to train the model, whereas steps only determine when to checkpoint. I'd rather give more time to properly educate users if we do end up taking away epochs entirely.

Regarding keep_last_n_checkpoints=1, my only question is whether there's a safer alternative than deleting files. Probably not...

Yeah, this makes me a little uncomfortable, too, but it does seem to be the standard in other training code like Titan and TNT.

@joecummings (Contributor, Author)

@calvinpelletier

Regarding keeping the last n checkpoints, should we also allow keeping the best checkpoint (once we add support for validation)? When I run training jobs for large models, I don't want to save too many checkpoints and run out of disk space, but I also don't want training to collapse overnight only to find that all the good checkpoints have already been deleted. Since we don't have validation/evaluation during training yet, we could choose the best checkpoint based on training loss for now.

Yes, I think this is a natural extension of this work. If you look at the TNT code, they actually have some scaffolding to do this. I'd be happy to add that as a follow-up or even have it be community contributed.

There's an annoying trade-off when choosing the step interval for checkpointing/validating/etc. Large intervals are nice because they don't slow training down too much, but small intervals give you quick feedback that everything is working correctly, and the extra data points are important early on when things are changing quickly. Personally, I use a logarithmic step-based checkpointing/validating/etc. schedule, where the interval is small during early training and gets longer as training goes on. I'm not sure if other people like this too, but it's trivial to implement and supporting it wouldn't complicate our code very much.

This makes a ton of sense. The way I'd like to incorporate it now is to make sure the design is easy to extend to such a use case when the time comes.
