-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix checkpoint saving #650
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks!
@@ -22,7 +22,7 @@ tabulate | |||
evaluate | |||
rouge_score | |||
pyyaml==6.0.1 | |||
faiss-gpu | |||
faiss-gpu; python_version < '3.11' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why python has to be less than 3.11, this is almost 6 months old stable python version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think your change needs to be updated here in pip packages
https://pypi.org/project/llama-recipes/#history which has been updated on 23rd july. I tested fine tuning notebook and currently its broken.
What does this PR do?
This PR
Fixes # (issue)
#646
Feature/Issue validation/testing
Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.
CUDA_VISIBLE_DEVICES=0,1,4,5 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --use_peft --peft_method lora --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 cd recipes/quickstart/inference/local_inference cat samsum_prompt.txt | python inference.py --model_name meta-llama/Meta-Llama-3.1-70B-Instruct --peft_model ~/llama_output/
Logs for Test A
CUDA_VISIBLE_DEVICES=2,3,6,7 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.FULL_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
Logs for Test B
CUDA_VISIBLE_DEVICES=2,3,6,7 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.SHARDED_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
Logs for Test C
``
samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.SHARDED_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779]
W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] *****************************************
W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] *****************************************
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning:
torch.distributed._shard.checkpoint
will be deprecated, use `torch.distributed.checkpoint` insteadfrom torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.23it/s]
--> Model meta-llama/Meta-Llama-3.1-8B
--> meta-llama/Meta-Llama-3.1-8B has 8030.261248 Million params
bFloat16 enabled for mixed precision - using bfSixteen policy
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.11it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.13it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.03it/s]
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
Preprocessing dataset: 19%|████████████████████████████████████████████████████████▉ | 2846/14732 [00:00<00:03, 3524.81it/s]
--> Training Set Length = 14732
Preprocessing dataset: 50%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 7373/14732 [00:02<00:02, 3356.12it/s]
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3282.43it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3324.44it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3415.49it/s]
Preprocessing dataset: 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 14219/14732 [00:04<00:00, 3262.10it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3461.39it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3162.09it/s]
Preprocessing dataset: 62%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 9177/14732 [00:02<00:01, 3413.66it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3081.66it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 69%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 10213/14732 [00:02<00:01, 3396.14it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3427.46it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3395.88it/s]
--> Num of Validation Set Batches loaded = 17
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Training Epoch: 1: 0%| | 0/79 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning:
torch.cpu.amp.autocast(args...)
is deprecated. Please usetorch.amp.autocast('cpu', args...)
instead.with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229): 1%|███▎ | 1/79 [00:07<09:17, 7.15s/it]
max training steps reached, stopping training, total train steps finished: 1
Training Epoch: 1/3, step 0/79 completed (loss: 1.4957325458526611): 1%|███▎ | 1/79 [00:09<12:53, 9.92s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4859018325805664): 1%|███▎ | 1/79 [00:09<12:00, 9.23s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229): 1%|███▎ | 1/79 [00:07<09:36, 7.39s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.5577670335769653): 1%|███▎ | 1/79 [00:09<12:27, 9.59s/it]
Max CUDA memory allocated was 19 GB
Max CUDA memory reserved was 29 GB
Peak active CUDA memory was 23 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:06, 2.37it/s]
max eval steps reached, stopping evaluation, total_eval_steps: 1
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:09, 1.69it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.80it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.79it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.80it/s]
eval_ppl=tensor(2.3358, device='cuda:0') eval_epoch_loss=tensor(0.8483, device='cuda:0')
Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT
====================================================================================================================================================================================================================
Saving model to /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning:
save_state_dict
is deprecated and will be removed in future versions.Please usesave
instead.dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning:
save_state_dict
is deprecated and will be removed in future versions.Please usesave
instead.dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning:
save_state_dict
is deprecated and will be removed in future versions.Please usesave
instead.dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning:
save_state_dict
is deprecated and will be removed in future versions.Please usesave
instead.dist_cp.save_state_dict(
Sharded state checkpoint saved to /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B
Checkpoint Time = 14.0877
best eval loss on epoch 1 is 0.8483395576477051
Epoch 1: train_perplexity=1.0189, train_epoch_loss=0.0188, epoch time 7.993172182934359s
training params are saved in /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B/train_params.yaml
Key: avg_train_prep, Value: 1.018941044807434
Key: avg_train_loss, Value: 0.01876385696232319
Key: avg_eval_prep, Value: 2.3357651233673096
Key: avg_eval_loss, Value: 0.8483395576477051
Key: avg_epoch_time, Value: 7.993172182934359
Key: avg_checkpoint_time, Value: 14.090817171148956
``
python recipes/quickstart/finetuning/finetuning.py --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --quantization 8bit`
Logs for Test D
Before submitting
Pull Request section?
to it if that's the case.
Thanks for contributing 🎉!