Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpoint saving #650

Merged
merged 7 commits into from
Aug 30, 2024
Merged

Fix checkpoint saving #650

merged 7 commits into from
Aug 30, 2024

Conversation

mreso
Copy link
Contributor

@mreso mreso commented Aug 28, 2024

What does this PR do?

This PR

  • removes double saving of checkpoints
  • fixes a situation where a users select to fine tune a model without peft and fsdp.

Fixes # (issue)
#646

Feature/Issue validation/testing

Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Test A
    CUDA_VISIBLE_DEVICES=0,1,4,5 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --use_peft --peft_method lora --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 cd recipes/quickstart/inference/local_inference cat samsum_prompt.txt | python inference.py --model_name meta-llama/Meta-Llama-3.1-70B-Instruct --peft_model ~/llama_output/
    Logs for Test A
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                           | 0/79 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:06<09:03,  6.97s/it]
max training steps reached, stopping training, total train steps finished:  1
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:07<09:10,  7.06s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.5577670335769653):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:08<11:21,  8.74s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4859018325805664):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:09<11:58,  9.21s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4957325458526611):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:08<10:53,  8.38s/it]
Max CUDA memory allocated was 19 GB
Max CUDA memory reserved was 23 GB
Peak active CUDA memory was 19 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:08,  1.97it/s]
max eval steps reached, stopping evaluation, total_eval_steps:  1
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:10,  1.55it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:10,  1.49it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:10,  1.49it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:11,  1.39it/s]
 eval_ppl=tensor(1.0902, device='cuda:0') eval_epoch_loss=tensor(0.0864, device='cuda:0')
we are about to save the PEFT modules
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
PEFT modules are saved in ../llama_output/ directory
best eval loss on epoch 1 is 0.08636830747127533
Epoch 1: train_perplexity=1.0189, train_epoch_loss=0.0188, epoch time 7.81007081293501s
Key: avg_train_prep, Value: 1.018941044807434
Key: avg_train_loss, Value: 0.01876385696232319
Key: avg_eval_prep, Value: 1.0902076959609985
Key: avg_eval_loss, Value: 0.08636830747127533
Key: avg_epoch_time, Value: 7.81007081293501
Key: avg_checkpoint_time, Value: 22.509945076191798
  • Test B
    CUDA_VISIBLE_DEVICES=2,3,6,7 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.FULL_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
    Logs for Test B
W0828 12:06:25.882000 139684707845120 torch/distributed/run.py:779]
W0828 12:06:25.882000 139684707845120 torch/distributed/run.py:779] *****************************************
W0828 12:06:25.882000 139684707845120 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0828 12:06:25.882000 139684707845120 torch/distributed/run.py:779] *****************************************
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead                                                                                                                                                          from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  8.11it/s]
--> Model meta-llama/Meta-Llama-3.1-8B

--> meta-llama/Meta-Llama-3.1-8B has 8030.261248 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.27it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.16it/s]Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.11it/s]
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...                                                                                                                                                                                                                                                                                                                             --> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Training Set Length = 14732                                                                                                                                                                                                                                                                                                                                           Preprocessing dataset:   0%|                                                                                                                                                                                                                                                                                                                    | 0/14732 [00:00<?, ?it/s]
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3482.75it/s]Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3475.14it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset:  98%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 14390/14732 [00:04<00:00, 3374.47it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3458.10it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3460.62it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3364.39it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3252.97it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3346.98it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3292.48it/s]
--> Num of Validation Set Batches loaded = 17
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                           | 0/79 [00:00<?, ?it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                           | 0/79 [00:00<?, ?it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                           | 0/79 [00:00<?, ?it/s]NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training Epoch: 1/3, step 0/79 completed (loss: 1.5577670335769653):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:07<09:46,  7.52s/it]max training steps reached, stopping training, total train steps finished:  1
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:07<09:56,  7.65s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4957325458526611):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:07<09:49,  7.55s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.5577670335769653):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:07<09:59,  7.69s/it]Training Epoch: 1/3, step 0/79 completed (loss: 1.4859018325805664):   1%|███▎                                                                                                                                                                                                                                                             | 1/79 [00:08<10:47,  8.30s/it]
Max CUDA memory allocated was 19 GB
Max CUDA memory reserved was 29 GB
Peak active CUDA memory was 20 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:09,  1.70it/s]max eval steps reached, stopping evaluation, total_eval_steps:  1
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:11,  1.45it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:11,  1.38it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:11,  1.44it/s]
evaluating Epoch:   6%|██████████████████                                                                                                                                                                                                                                                                                                  | 1/17 [00:00<00:11,  1.38it/s]
 eval_ppl=tensor(2.3370, device='cuda:0') eval_epoch_loss=tensor(0.8489, device='cuda:0')
 Saving the FSDP model checkpoint using FULL_STATE_DICT Saving the FSDP model checkpoint using FULL_STATE_DICT Saving the FSDP model checkpoint using FULL_STATE_DICT

==========================================================================================================


===================================================== Saving the FSDP model checkpoint using FULL_STATE_DICT

=====================================================
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
saving process: rank 2  done w model state_dict

saving process: rank 1  done w model state_dict

saving process: rank 3  done w model state_dict

saving process: rank 0  done w model state_dict

--> saving model ...
model checkpoint saved for epoch 0 at /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B/meta-llama--Meta-Llama-3.1-8B-0.pt

best eval loss on epoch 1 is 0.8488507866859436
Epoch 1: train_perplexity=1.0189, train_epoch_loss=0.0188, epoch time 8.528414465952665s
training params are saved in /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B/train_params.yaml
Key: avg_train_prep, Value: 1.018941044807434
Key: avg_train_loss, Value: 0.01876385696232319
Key: avg_eval_prep, Value: 2.3369596004486084
Key: avg_eval_loss, Value: 0.8488507866859436
Key: avg_epoch_time, Value: 8.528414465952665
Key: avg_checkpoint_time, Value: 40.176256065955386
  • Test C
    CUDA_VISIBLE_DEVICES=2,3,6,7 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.SHARDED_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
    Logs for Test C
    ``
    samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --fsdp_config.checkpoint_type StateDictType.SHARDED_STATE_DICT --dist_checkpoint_root_folder ../llama_output_fsdp/
    W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779]
    W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] *****************************************
    W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
    W0828 12:27:59.588000 140014184924160 torch/distributed/run.py:779] *****************************************
    /home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: torch.distributed._shard.checkpoint will be deprecated, use `torch.distributed.checkpoint` instead
    from torch.distributed._shard.checkpoint import (
    /home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
    from torch.distributed._shard.checkpoint import (
    /home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
    from torch.distributed._shard.checkpoint import (
    /home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
    from torch.distributed._shard.checkpoint import (
    Clearing GPU cache for all ranks
    --> Running with torch dist debug set to detail
    Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.23it/s]
    --> Model meta-llama/Meta-Llama-3.1-8B

--> meta-llama/Meta-Llama-3.1-8B has 8030.261248 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.11it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.13it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 4.03it/s]
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
Preprocessing dataset: 19%|████████████████████████████████████████████████████████▉ | 2846/14732 [00:00<00:03, 3524.81it/s]
--> Training Set Length = 14732
Preprocessing dataset: 50%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 7373/14732 [00:02<00:02, 3356.12it/s]
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3282.43it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3324.44it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3415.49it/s]
Preprocessing dataset: 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 14219/14732 [00:04<00:00, 3262.10it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3461.39it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3162.09it/s]
Preprocessing dataset: 62%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 9177/14732 [00:02<00:01, 3413.66it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3081.66it/s]
--> Num of Validation Set Batches loaded = 17
Preprocessing dataset: 69%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 10213/14732 [00:02<00:01, 3396.14it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3427.46it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3395.88it/s]
--> Num of Validation Set Batches loaded = 17
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
Training Epoch: 1: 0%| | 0/79 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: torch.cpu.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cpu', args...) instead.
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229): 1%|███▎ | 1/79 [00:07<09:17, 7.15s/it]
max training steps reached, stopping training, total train steps finished: 1
Training Epoch: 1/3, step 0/79 completed (loss: 1.4957325458526611): 1%|███▎ | 1/79 [00:09<12:53, 9.92s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4859018325805664): 1%|███▎ | 1/79 [00:09<12:00, 9.23s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3899767398834229): 1%|███▎ | 1/79 [00:07<09:36, 7.39s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.5577670335769653): 1%|███▎ | 1/79 [00:09<12:27, 9.59s/it]
Max CUDA memory allocated was 19 GB
Max CUDA memory reserved was 29 GB
Peak active CUDA memory was 23 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:06, 2.37it/s]
max eval steps reached, stopping evaluation, total_eval_steps: 1
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:09, 1.69it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.80it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.79it/s]
evaluating Epoch: 6%|██████████████████ | 1/17 [00:00<00:08, 1.80it/s]
eval_ppl=tensor(2.3358, device='cuda:0') eval_epoch_loss=tensor(0.8483, device='cuda:0')
Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT

====================================================================================================================================================================================================================

Saving model to /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:737: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:749: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:751: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:752: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning: save_state_dict is deprecated and will be removed in future versions.Please use save instead.
dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning: save_state_dict is deprecated and will be removed in future versions.Please use save instead.
dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning: save_state_dict is deprecated and will be removed in future versions.Please use save instead.
dist_cp.save_state_dict(
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:113: FutureWarning: save_state_dict is deprecated and will be removed in future versions.Please use save instead.
dist_cp.save_state_dict(
Sharded state checkpoint saved to /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B
Checkpoint Time = 14.0877

best eval loss on epoch 1 is 0.8483395576477051
Epoch 1: train_perplexity=1.0189, train_epoch_loss=0.0188, epoch time 7.993172182934359s
training params are saved in /home/mreso/llama-recipes/../llama_output_fsdp/fine-tuned-meta-llama/Meta-Llama-3.1-8B/train_params.yaml
Key: avg_train_prep, Value: 1.018941044807434
Key: avg_train_loss, Value: 0.01876385696232319
Key: avg_eval_prep, Value: 2.3357651233673096
Key: avg_eval_loss, Value: 0.8483395576477051
Key: avg_epoch_time, Value: 7.993172182934359
Key: avg_checkpoint_time, Value: 14.090817171148956
``

  • Test D
    python recipes/quickstart/finetuning/finetuning.py --model_name meta-llama/Meta-Llama-3.1-8B --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1 --quantization 8bit`
    Logs for Test D
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.19s/it]
--> Model meta-llama/Meta-Llama-3.1-8B

--> meta-llama/Meta-Llama-3.1-8B has 1050.939392 Million params

--> Training Set Length = 14732
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3202.90it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 2500.57it/s]
--> Num of Validation Set Batches loaded = 69
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                          | 0/319 [00:00<?, ?it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
Training Epoch: 1/3, step 0/319 completed (loss: 1.5929347276687622):   0%|▊                                                                                                                                                                                                                                                              | 1/319 [00:09<47:41,  9.00s/it]
max training steps reached, stopping training, total train steps finished:  1
Training Epoch: 1/3, step 0/319 completed (loss: 1.5929347276687622):   0%|▊                                                                                                                                                                                                                                                              | 1/319 [00:09<48:31,  9.16s/it]
Max CUDA memory allocated was 38 GB
Max CUDA memory reserved was 41 GB
Peak active CUDA memory was 38 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 10 GB
evaluating Epoch:   1%|████▍                                                                                                                                                                                                                                                                                                               | 1/69 [00:00<01:06,  1.02it/s]
max eval steps reached, stopping evaluation, total_eval_steps:  1
evaluating Epoch:   1%|████▍                                                                                                                                                                                                                                                                                                               | 1/69 [00:01<01:15,  1.11s/it]
 eval_ppl=tensor(1.0213, device='cuda:0') eval_epoch_loss=tensor(0.0211, device='cuda:0')
best eval loss on epoch 1 is 0.02108839526772499
Epoch 1: train_perplexity=1.0050, train_epoch_loss=0.0050, epoch time 11.862033671932295s
Key: avg_train_prep, Value: 1.0050060749053955
Key: avg_train_loss, Value: 0.0049935257993638515
Key: avg_eval_prep, Value: 1.0213123559951782
Key: avg_eval_loss, Value: 0.02108839526772499
Key: avg_epoch_time, Value: 11.862033671932295
Key: avg_checkpoint_time, Value: 18.797315332805738

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!

@mreso mreso mentioned this pull request Aug 30, 2024
7 tasks
@mreso mreso requested a review from wukaixingxp August 30, 2024 18:13
Copy link
Contributor

@wukaixingxp wukaixingxp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks!

@mreso mreso merged commit 778e31e into main Aug 30, 2024
3 checks passed
@mreso mreso deleted the fix/peft_qs_nb_non_fsdp_cpt branch August 30, 2024 19:08
@@ -22,7 +22,7 @@ tabulate
evaluate
rouge_score
pyyaml==6.0.1
faiss-gpu
faiss-gpu; python_version < '3.11'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why python has to be less than 3.11, this is almost 6 months old stable python version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your change needs to be updated here in pip packages
https://pypi.org/project/llama-recipes/#history which has been updated on 23rd july. I tested fine tuning notebook and currently its broken.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants