call #268

Open · wants to merge 236 commits into base: dev

Commits (236)
5e67140
add trtllm support
jiemingz Dec 17, 2023
d955bb4
dataloader reshard
jiemingz Dec 15, 2023
f75ae54
cleanup code and support trtllm 0.7
Feb 22, 2024
25d8aae
cleanup and sync with nemo exprot
Feb 22, 2024
5aa9dd4
remove resharding
Feb 22, 2024
ac8933c
fix topp and refactor
Feb 23, 2024
f023c7d
clamp tokens
Feb 24, 2024
46828bf
resharding PP support
Feb 29, 2024
8d50946
add mbs support
gshennvm Feb 13, 2024
0a0893a
Added nondivisible batch support on critic
Mar 7, 2024
c25cf6e
format only
gshennvm Mar 13, 2024
36f9664
log prob batch
gshennvm Mar 13, 2024
e52a35d
fix
gshennvm Mar 13, 2024
1ae9987
debug
gshennvm Mar 14, 2024
60aaecb
fix
gshennvm Mar 14, 2024
d68ed96
test
gshennvm Mar 17, 2024
b2bfa6f
dummy
gshennvm Mar 19, 2024
297255d
fix
gshennvm Mar 20, 2024
d4910b7
fix
gshennvm Mar 20, 2024
4a52afe
restore critic
gshennvm Mar 20, 2024
2140f36
fix
gshennvm Mar 21, 2024
e7ea083
add logging
gshennvm Mar 21, 2024
6cd8923
fix critic client
gshennvm Mar 21, 2024
887874d
fix
gshennvm Mar 21, 2024
6e2ca2d
fix
gshennvm Mar 21, 2024
bacd786
fix
gshennvm Mar 21, 2024
17b04ca
fix
gshennvm Mar 21, 2024
9fee58b
fix typo
gshennvm Mar 21, 2024
61ba204
better train timing
gshennvm Mar 21, 2024
c9dad2a
remove prints
gshennvm Mar 21, 2024
e5f68e2
with timing
gshennvm Mar 21, 2024
74791ae
delete unused func
gshennvm Mar 21, 2024
e40ebd6
add critic logging
gshennvm Mar 21, 2024
72ba6c6
add
gshennvm Apr 1, 2024
bfb61e4
cleanup
gshennvm Apr 6, 2024
21206c5
update
gshennvm Apr 6, 2024
148acf4
fix
gshennvm Apr 6, 2024
d7c9990
fix bug
gshennvm Apr 6, 2024
537d6e5
fix bug
gshennvm Apr 6, 2024
47400ba
test
gshennvm Apr 6, 2024
d7b2b23
fix bug
gshennvm Apr 6, 2024
ce76226
fix
gshennvm Apr 7, 2024
8edf534
add
gshennvm Apr 7, 2024
6379a2e
fix
gshennvm Apr 8, 2024
eadae31
fix again
gshennvm Apr 8, 2024
e2b97d9
fix
gshennvm Apr 8, 2024
d9bdf7c
fix mean
gshennvm Apr 8, 2024
1c7d215
fix
gshennvm Apr 8, 2024
3638301
add debug
gshennvm Apr 8, 2024
4cca85f
fix
gshennvm Apr 8, 2024
1b19bdd
add data iter for VP
gshennvm Apr 8, 2024
3f045ae
move
gshennvm Apr 8, 2024
3c9fe3d
fixing
gshennvm Apr 8, 2024
f36f394
add
gshennvm Apr 8, 2024
5211bc2
chunking needs to be moved out
gshennvm Apr 8, 2024
0f59edf
fix
gshennvm Apr 8, 2024
c3fe2f7
fix metrics
gshennvm Apr 9, 2024
5d3e07d
fix dtype
gshennvm Apr 9, 2024
15887e5
merge
gshennvm Apr 9, 2024
2ad76ba
fix
gshennvm Apr 9, 2024
9d9a6b6
make the global id management into a class
gshennvm Apr 10, 2024
d6fb55d
fix
gshennvm Apr 11, 2024
0983164
trtllm0.9 changes (#149)
jiemingz Apr 17, 2024
fe765cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
dfac922
trtllm patch file
jiemingz Apr 17, 2024
c159aa3
dockerfile
jiemingz Apr 17, 2024
d81caef
fix build
gshennvm Apr 18, 2024
92c19f6
fix bug
gshennvm Apr 18, 2024
7088f54
add groupnorm build
jiemingz Apr 19, 2024
472a56c
upgrade to latest te and mcore
gshennvm Apr 18, 2024
032bf35
Merge remote-tracking branch 'origin/dev' into aligner_trt_build
gshennvm Apr 19, 2024
d5f55f5
fix
gshennvm Apr 19, 2024
c7cdca1
specify max token
Apr 20, 2024
04d02c8
fix
gshennvm Apr 20, 2024
56ccacf
Merge remote-tracking branch 'origin/geshen/main_trt' into aligner_tr…
gshennvm Apr 20, 2024
d23865f
fix critic checkpoint loading
gshennvm Apr 22, 2024
3c21c81
add assert
gshennvm Apr 22, 2024
2c99dcb
fix bug
gshennvm Apr 22, 2024
9e8526d
fix
gshennvm Apr 22, 2024
c1daeb9
fix
gshennvm Apr 23, 2024
e16c357
update dockerfile
gshennvm Apr 23, 2024
410eaf5
update to 24.03.01 deps
gshennvm Apr 23, 2024
e405432
fix
gshennvm Apr 24, 2024
07cfa67
update dockerfile
gshennvm Apr 24, 2024
b2dfee0
add dockerfileg
gshennvm Apr 26, 2024
63cd6b3
fix trtllm patch
Apr 29, 2024
6901348
clamp output with warning
Apr 29, 2024
74a0bb1
fix
gshennvm Apr 29, 2024
b6a05fd
remove debug statements
gshennvm Apr 30, 2024
db2701b
Merge remote-tracking branch 'origin/main' into aligner_trt_build
gshennvm Apr 30, 2024
8dd5c59
add debug info
gshennvm May 6, 2024
b5d6f88
bump pytrition version
gshennvm May 6, 2024
5464827
add critic speed
gshennvm May 7, 2024
00e4298
critic speedup
gshennvm May 7, 2024
fe6864b
Merge remote-tracking branch 'origin/geshen/critic_refactor' into ges…
gshennvm May 7, 2024
80579ec
fix
gshennvm May 13, 2024
f81f55a
add pad sequence length
gshennvm May 14, 2024
4a034f4
dockerfile
gshennvm May 15, 2024
66b5a54
higher stability
gshennvm May 16, 2024
7841381
Merge remote-tracking branch 'origin/main' into geshen/debug_critic
gshennvm May 16, 2024
1779c51
add
gshennvm May 16, 2024
e357ef9
add hack for ckpt
gshennvm May 24, 2024
6b606e8
fix conf
gshennvm May 24, 2024
a669837
no import
gshennvm May 24, 2024
ada5f45
add
gshennvm May 24, 2024
393acc6
fix
gshennvm May 25, 2024
b6a4d59
run through
gshennvm May 25, 2024
977e6e7
fix
gshennvm May 25, 2024
621718d
adaptive
gshennvm May 25, 2024
6109b8b
output tensor
gshennvm May 25, 2024
866c22b
add logging
gshennvm May 26, 2024
02aa2b8
fix for llama3
gshennvm May 28, 2024
e6f27c5
disable last checkpoint
gshennvm May 31, 2024
c689d2a
fix padding bug
gshennvm Jun 1, 2024
cd4aaa5
add critic warmup
gshennvm Jun 1, 2024
993e358
Revert "add"
gshennvm Jun 8, 2024
28fcaf3
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jun 8, 2024
ef347e5
fix module missing bug
gshennvm Jun 11, 2024
752d0bd
Ensure critic server does not squeeze out a singleton batch dim (#199)
terrykong Jun 11, 2024
78e6536
Merge branch 'geshen/llama3_rlhf' into geshen/trt_llm_to_main
gshennvm Jun 12, 2024
4de3eeb
Merge branch 'geshen/trt_llm_to_main' of github.com:NVIDIA/NeMo-Align…
gshennvm Jun 12, 2024
8a39881
TRTLLM PP wrong format WAR
jiemingz May 17, 2024
666e969
docker file branch
gshennvm Jun 12, 2024
3bec1bc
fix config
gshennvm Jun 12, 2024
3e7ca5f
remove prints
gshennvm Jun 12, 2024
12a0aae
remove print
gshennvm Jun 12, 2024
3956b6d
remove unneeded statement
gshennvm Jun 12, 2024
e090663
no save topk
gshennvm Jun 12, 2024
af83947
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jun 24, 2024
cc03b76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
605bda1
critic speedup
gshennvm Jun 24, 2024
b3dedfd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
9fb90ff
fix
gshennvm Jun 24, 2024
bf62bcc
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 24, 2024
aea50ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2024
0a9416e
fix
gshennvm Jun 24, 2024
41ffeb5
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 24, 2024
e3d85bf
add parallel_state
gshennvm Jun 24, 2024
1225041
fix
gshennvm Jun 24, 2024
5315fc5
rename
gshennvm Jun 24, 2024
9c6141b
fix
gshennvm Jun 24, 2024
2cef93e
fix
gshennvm Jun 24, 2024
f5bd8c5
pull changes from degert/spin-trt-beta (#220)
gshennvm Jun 25, 2024
cbd7095
add text input support
gshennvm Jun 25, 2024
c72e4c1
clean up location of tokenization
gshennvm Jun 25, 2024
220cf17
remove useless imports
gshennvm Jun 25, 2024
1ed3c46
refactor rm server
gshennvm Jun 26, 2024
aa4249d
remove run rm file
gshennvm Jun 26, 2024
22615b7
remove preferred batch size logic
gshennvm Jun 26, 2024
95a9507
add comment
gshennvm Jun 26, 2024
0fdb124
allow users to specify their own preferred batch size
gshennvm Jun 26, 2024
031287c
add comments and add changelog
gshennvm Jun 26, 2024
fa37930
update changelog
gshennvm Jun 26, 2024
c83a98d
remove old reward model callable
gshennvm Jun 26, 2024
a9f722a
inference should be done with collect loss data =True
gshennvm Jun 26, 2024
15460d2
Merge branch 'geshen/critic_speedup' into geshen/trt_llm_to_main
gshennvm Jun 27, 2024
f6c09ea
fix issues with merge
gshennvm Jun 27, 2024
1d85152
cleanup configs
gshennvm Jun 27, 2024
b3078c5
add strip to sequence length
gshennvm Jun 27, 2024
1eb0823
Merge remote-tracking branch 'origin/geshen/critic_speedup' into gesh…
gshennvm Jun 27, 2024
6203e90
change
gshennvm Jun 27, 2024
e51c45f
fix
gshennvm Jun 27, 2024
1b220f1
clean actor
gshennvm Jun 28, 2024
0815b56
backwards compatibility in actor
gshennvm Jun 28, 2024
8d75cbf
Apply suggestions from code review
gshennvm Jun 28, 2024
ab4c549
modify changelog
gshennvm Jun 28, 2024
82fb3a1
fixup! modify changelog
gshennvm Jun 28, 2024
12f85d2
add comments to ppo_critic config
gshennvm Jun 28, 2024
f90f4d6
add note on breaking change in inference rm
gshennvm Jun 28, 2024
527557d
change inference mbs to 4
gshennvm Jun 28, 2024
fe9d288
add comments for inference rm config
gshennvm Jun 28, 2024
f3124d3
revert gbs flag back to previous in ppo critic
gshennvm Jun 28, 2024
efadcae
delete unused variable
gshennvm Jun 28, 2024
35a2895
Update nemo_aligner/algorithms/critic_server_trainer.py
gshennvm Jun 28, 2024
4487932
remove add_eos arg, and update attribute annotate script
gshennvm Jun 29, 2024
7e7f27b
Merge branch 'geshen/critic_speedup' of github.com:NVIDIA/NeMo-Aligne…
gshennvm Jun 29, 2024
e9c7b39
no mutation on inputs when processing them for inference
gshennvm Jun 29, 2024
c6f6da4
fix bug when padding
gshennvm Jun 29, 2024
ebb69f4
add comment for forward_micro_batch_size in training_rm.yaml
gshennvm Jun 29, 2024
2775e81
change non_blocking to use sync
gshennvm Jun 29, 2024
fe0399f
Merge branch 'geshen/critic_speedup' into geshen/trt_llm_to_main
gshennvm Jun 30, 2024
ffa253f
nemo export api changes
jiemingz Jul 1, 2024
7ca9e34
upgrade to newer nemo export
gshennvm Jul 1, 2024
8181168
fix imports
gshennvm Jul 1, 2024
4d0853d
Communicator hang fix in the actor loop (#200)
terrykong Jul 1, 2024
ec548b8
add nemo guard for when things don't stop properly
gshennvm Jul 3, 2024
ce7a07f
cleanup communicator clean
gshennvm Jul 3, 2024
bb2fc48
fix
gshennvm Jul 3, 2024
606f690
critic speedup
gshennvm Jun 24, 2024
f48dc29
middle of PP should be broadcasted as well
gshennvm Jul 11, 2024
708bc24
update with critic changes
gshennvm Jul 11, 2024
48ad685
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jul 11, 2024
d053475
general cleanup
gshennvm Jul 11, 2024
0b4a92d
add checker for trt
gshennvm Jul 11, 2024
b72a5ec
remove comments
gshennvm Jul 11, 2024
984acaa
fix
gshennvm Jul 12, 2024
c11e1d7
fix
gshennvm Jul 12, 2024
14a9926
another fix
gshennvm Jul 12, 2024
7c2fc3e
add typing
gshennvm Jul 12, 2024
fe02867
cleanup
gshennvm Jul 12, 2024
02ad2fa
ppo trainer should use stop and get time
gshennvm Jul 12, 2024
24c53be
add some comments
gshennvm Jul 12, 2024
8a25e5e
critic warmup should have good default
gshennvm Jul 12, 2024
24f138a
added ppo in changelog
gshennvm Jul 12, 2024
9c72c53
add comments
gshennvm Jul 12, 2024
5ed9cd8
Avoids crash in PPOTrainer if using adafactor w/o learning rate (#234)
terrykong Jul 12, 2024
8b6627a
rename
gshennvm Jul 12, 2024
56032c8
Merge branch 'geshen/trt_llm_to_main' of github.com:NVIDIA/NeMo-Align…
gshennvm Jul 12, 2024
1e17f8b
Raise exceptions if using trtllm and use_Greedy in sampling params is…
terrykong Jul 12, 2024
e0a94d0
fix bugs
gshennvm Jul 12, 2024
835b3b3
cleanup pad id handling when PP > 1
gshennvm Jul 13, 2024
2b95331
fix issue with PP > 1 check
gshennvm Jul 14, 2024
261269a
add is_end logic
gshennvm Jul 14, 2024
83ba660
add is end logic
gshennvm Jul 14, 2024
f3912e7
add is end logic
gshennvm Jul 14, 2024
09d2783
fix
gshennvm Jul 14, 2024
5105ed9
fix
gshennvm Jul 14, 2024
a2bf8a0
fix another bug
gshennvm Jul 15, 2024
d9d45d6
Merge remote-tracking branch 'origin/main' into geshen/trt_llm_to_main
gshennvm Jul 15, 2024
f00d09e
update changelog
gshennvm Jul 15, 2024
d52eab2
change rc version
gshennvm Jul 15, 2024
12fa459
Update the hash of the conversion script to include TE fix for mcore …
terrykong Jul 16, 2024
3d0a650
update docs
gshennvm Jul 16, 2024
e103759
Addresses documentation bugs (#244)
terrykong Jul 22, 2024
db3a9cc
Updates microbatch APIs to use megatron's instead of apex's (#241)
terrykong Jul 23, 2024
7dddeef
add pref
gshennvm Jul 31, 2024
cc26456
fix HF links in readme (#248)
gshennvm Jul 23, 2024
44701e5
Point the TRTLLM documentation to Nemo docs (#257)
terrykong Aug 2, 2024
a1f99fd
add scaling efficiency instead
gshennvm Aug 2, 2024
e741e76
fix ending
gshennvm Aug 2, 2024
e4ff1ec
fix logic
gshennvm Aug 2, 2024
46afc6b
0.4 doc tech edit (#260)
terrykong Aug 5, 2024
9d5c71f
Restore missing newline in performance table
terrykong Aug 5, 2024
928d4a9
fix rlhf perf table column headings
terrykong Aug 6, 2024
384394f
Updates MBS calculator APIs to reconfigure_num_microbatches_calculator
terrykong Aug 6, 2024
7eb9a97
call
gshennvm Aug 20, 2024
76 changes: 39 additions & 37 deletions CHANGELOG.md
@@ -3,75 +3,77 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Next Version]
## [0.4.0]
- Implement reward-aware preference optimization.
- Added TRT-LLM support in PPO. This can be enabled by doing `trainer.ppo.trt_llm.enable=True`. There is also a reshard option to reshard out pipeline parallelism during inference for further speedup via `trainer.ppo.trt_llm.reshard=True`.
- PPO algorithm will now detect if the sample sequence is ended, and if so zero out the gradient of the samples that did not stop properly.
- Added critic warmup to the PPO with the flag trainer.ppo.critic_warmup_steps.
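
The three items above are plain Hydra command-line overrides on the PPO actor. A minimal sketch, assuming the usual `train_gpt_ppo_actor.py` entry point (the script path and the warmup value are illustrative; only the `trainer.ppo.*` keys come from the entries above):

```bash
# Sketch: launch the PPO actor with the new TRT-LLM generation path.
# Only the trainer.ppo.* keys are taken from the changelog above; the script
# path and the numeric value are assumptions for illustration.
python examples/nlp/gpt/train_gpt_ppo_actor.py \
    trainer.ppo.trt_llm.enable=True \
    trainer.ppo.trt_llm.reshard=True \
    trainer.ppo.critic_warmup_steps=10
```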

### New features and optimizations
- Critic and Reward Model server refactored. Now the reward model will have a flag called `model.forward_micro_batch_size` which determines the micro batch size that it runs inferences with. This can be higher than the training micro batch size since during inference we have less memory pressure.
- In the critic and reward model server it is now possible to specify `inference_micro_batch_size` as a list, this allows us to give more information to PyTriton on the preferred batch sizes we want to run inference with.
### New Features and Optimizations
- Critic and Reward Model server refactored. Now the reward model will have a flag called `model.forward_micro_batch_size` which determines the micro batch size on which it runs inferences. This can be higher than the training micro batch size since during inference, we have less memory pressure.
- In the critic and reward model server, it is now possible to specify `inference_micro_batch_size` as a list. This allows us to provide more information to PyTriton regarding the preferred batch sizes for inference.
- It is no longer a requirement to specify `num_rollout_samples` to be a multiple of `inference_micro_batch_size * dp size` in PPO.
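
The server-side knobs above are likewise Hydra overrides. A minimal sketch for the reward-model server, assuming a `serve_reward_model.py` entry point (the script path and concrete values are illustrative; the two keys are the ones named above):

```bash
# Sketch: serve the reward model with the refactored batching options.
# model.forward_micro_batch_size controls the inference micro batch size;
# inference.inference_micro_batch_size may now be a list of preferred
# PyTriton batch sizes. Script path and values are assumptions.
python examples/nlp/gpt/serve_reward_model.py \
    model.forward_micro_batch_size=8 \
    "inference.inference_micro_batch_size=[4,8]"
```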

### Breaking changes
- `inference.micro_batch_size` is now renamed to `inference.inference_micro_batch_size` when running reward model inference in `inference_rm.yaml` this is to stay consistent with the naming scheme of the PPO critic.
### Breaking Changes
- `inference.micro_batch_size` is now renamed to `inference.inference_micro_batch_size` when running reward model inference in `inference_rm.yaml`. This is to stay consistent with the naming scheme of the PPO critic.
- It is no longer possible to specify `add_EOS` when running reward model or critic inference.
- NeMo-Aligner now requires Megatron-LM>=0.8.0 for the APIs to calculate the microbatch sizes.
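
As a concrete before/after for the rename above (same assumed entry point):

```bash
# Old key (pre-change), shown for contrast:
#   python examples/nlp/gpt/serve_reward_model.py inference.micro_batch_size=4
# New key, consistent with the PPO critic naming:
python examples/nlp/gpt/serve_reward_model.py \
    inference.inference_micro_batch_size=4
```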

### Bug Fixes
- Make `num_workers` for dataloaders 0 by default. This prevents issues when using MPI (with TRT-LLM) or more sophisticated launchers.

## [0.3.1] - 2024-05
- SPIN: added `rollout_micro_batch_size` parameter which allows users to set the batch size for doing generation during SPIN training.
previously the generation batch size was automatically set to the data parallel size (DP) of the model
- SPIN: added wandb logging of average generation length and a small sample of generated responses (in plaintext) along with corresponding prompts
- SPIN: added `rollout_micro_batch_size` parameter which allows users to set the batch size for doing generation during SPIN training. Previously, the generation batch size was automatically set to the data parallel size (DP) of the model.
- SPIN: added wandb logging of average generation length and a small sample of generated responses (in plaintext) along with their corresponding prompts.
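
A minimal sketch of the new SPIN knob; both the script path and the `model.spin.*` prefix are assumptions, only `rollout_micro_batch_size` is named above:

```bash
# Sketch: pin the SPIN generation batch size instead of defaulting to the
# data-parallel size. Script path and config prefix are assumptions.
python examples/nlp/gpt/train_gpt_spin.py \
    model.spin.rollout_micro_batch_size=8
```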

### New features and optimizations
### New Features and Optimizations
- Add MoE Support for our reward models.
- SFT/SteerLM: LoRA can now be enabled on all model layers
- DPO: Enable LoRA on all model layers (In this case the actor will be reference model + LoRA weights, we can switch between actor/reference model by enabling/disabling LoRA)
- PPO: Enable LoRA on all model layers (In this case the actor will be init policy + LoRA weights, we can switch between actor/init_policy model by enabling/disabling LoRA)
- SFT/SteerLM: LoRA can now be enabled on all model layers.
- DPO: Enable LoRA on all model layers. In this case, the actor will be a reference model plus LoRA weights. We can switch between the actor/reference model by enabling or disabling LoRA.
- PPO: Enable LoRA on all model layers. In this case, the actor will be the init policy plus LoRA weights. We can switch between the actor/init_policy model by enabling or disabling LoRA.
- SteerLM 2.0: Add the SteerLM 2.0 model alignment method.
- Added support for float values for `val_check_interval` for SFT
- Added support for `limit_train_batches` as a float or int to DPO, SPIN, and SFT. This functionality mirrors the same parameter in PTL
### Breaking changes
- `val_check_interval` in SFT now supports float values.
- Added support for `limit_train_batches` as a float or int to DPO, SPIN, and SFT. This functionality mirrors the same parameter in PTL.
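
A sketch of the float-valued knobs above in an SFT run; the script path and the `trainer.sft.*` prefix are assumptions for illustration:

```bash
# Sketch: validate every quarter epoch and train on half the batches.
# Script path and key prefix are assumptions; the float semantics follow PTL.
python examples/nlp/gpt/train_gpt_sft.py \
    trainer.sft.val_check_interval=0.25 \
    trainer.sft.limit_train_batches=0.5
```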

### Breaking Changes

### Bug Fixes
- Fixed issue where random sampler keeps state when resetting for validation, leading to a different validation batch each validation step. Fixed by using a deterministic sampler
- Fixed crash with float val check interval in DPOTrainer
- Fixed crash with float val check interval when checking progress in DPOTrainer
- Fixed potential crash in SPIN when prompts are longer than encoder_seq_len - generation.max_length
- Fixed crash when calling the `generate()` method of an SFT model with pipeline parallelism greater than two
- Fixed crash when calling the `generate()` method of an SFT model with `compute_logprob=True` and string inputs
- Fixed crash when `model.micro_batch_size` > 1 in DPO
- Fixed issue where the random sampler keeps its state during validation resets, resulting in varying validation batches at each step. This was addressed by switching to a deterministic sampler.
- Fixed crash with float val check interval in DPOTrainer.
- Fixed crash with float val check interval when checking progress in DPOTrainer.
- Fixed potential crash in SPIN when prompts are longer than encoder_seq_len - generation.max_length.
- Fixed crash when calling the `generate()` method of an SFT model with pipeline parallelism greater than two.
- Fixed crash when calling the `generate()` method of an SFT model with `compute_logprob=True` and string inputs.
- Fixed crash when `model.micro_batch_size` > 1 in DPO.
- Fixed issue when `model.encoder_seq_length` is mismatched with `model.data.train_ds.max_seq_length` in SFT and SPIN.
- Delete MegatronPretrainingRandomSampler from Aligner since it has been upstreamed into NeMo
- Fixed SPIN not correctly using its `val_check_interval` parameter
- Delete MegatronPretrainingRandomSampler from NeMo-Aligner since it has been upstreamed into NeMo.
- Fixed SPIN not correctly using its `val_check_interval` parameter.

## [0.3.0] - 2024-05

### New features and optimizations
### New Features and Optimizations
- Special TRT-LLM release. See [Accelerated-RLHF](https://github.com/NVIDIA/NeMo-Aligner/blob/v0.3.0.trtllm/Accelerated-RLHF.md) and [Accelerated-RLHF-Release](https://github.com/NVIDIA/NeMo-Aligner/releases/tag/v0.3.0.trtllm) for more details.

## [0.2.0] - 2024-02
### New features and optimizations
### New Features and Optimizations
- Added public-facing official Dockerfile for NeMo-Aligner.
- PPO: memory optimization to help avoid OOM in the actor when sending training data to the critic.
- PPO: it is now possible to use a custom end string in `sampling_params.end_strings` that is different from `<extra_id_1>`.
- SFT: added support for custom validation metrics based on model generations.
- Added the ability to do multi-epoch (cfg.max_epochs > 1) training for reward models, DPO, PPO, and SFT
- Added the SPIN (Self-Play Fine Tuning) algorithm (https://arxiv.org/abs/2401.01335) which allows SPIN SFT training using SFT-format dataset files
- SFT/SteerLM: added LoRA tuning as an option besides full fine-tuning, only attention_qkv layer is supported
- Added the ability to do multi-epoch (cfg.max_epochs > 1) training for reward models, DPO, PPO, and SFT.
- Added the SPIN (Self-Play Fine Tuning) algorithm (https://arxiv.org/abs/2401.01335) which allows SPIN SFT training using SFT-format dataset files.
- SFT/SteerLM: added LoRA tuning as an option besides full fine-tuning, only attention_qkv layer is supported.
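
For the custom end-string item above, a sketch of the override (the full key prefix and script path are assumptions; only `sampling_params.end_strings` is named in the entry):

```bash
# Sketch: use a custom generation end string in PPO instead of <extra_id_1>.
# Key prefix and script path are assumptions.
python examples/nlp/gpt/train_gpt_ppo_actor.py \
    'model.ppo.sampling_params.end_strings=["</s>"]'
```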

### Breaking changes
- We have changed the shuffle logic in the data sampler to support multi-epoch training, so training runs using identical parameters
will not give the same results anymore because the shuffle logic has changed (specifically the seed value is modified slightly per epoch).
If you run CI/regression type tests, then be warned that the test may break due to this shuffle change.
### Breaking Changes
- We have changed the shuffle logic in the data sampler to support multi-epoch training, so training runs using identical parameters will no longer give the same results (specifically, the seed value is modified slightly per epoch). If you run CI/regression-type tests, be warned that they may break due to this shuffle change.

### Bug Fixes
- Fixed a potential issue when the base model's `model.data.data_prefix` config is a list and is about to be overridden with
a dictionary from the training configuration.
- `exp_manager.max_time_per_run` is now respected, the trainers will save and run validation before exiting if we've reached the time limit.
- `exp_manager.max_time_per_run` is now respected. The trainers will save and run the validation before exiting if the time limit has been reached.
- Fixed crash in PPO when using a separate reward model server (i.e., with `combine_rm_and_critic_server=False`).
- Fixed crash when LR scheduler is not specified
- Fixed crash when LR scheduler is not specified.

## [0.1.0] - 2023-12-04
### Added
- First open source release
- First open source release.
25 changes: 23 additions & 2 deletions Dockerfile
@@ -1,12 +1,18 @@
# CUDA 12.3
FROM nvcr.io/nvidia/pytorch:24.02-py3
FROM nvcr.io/nvidia/pytorch:24.02-py3

### config tags
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
ARG TE_TAG=a51ff542dcb1f605aa54f9b0e1aaadb132acd53d
ARG MLM_TAG=core_r0.7.0
ARG NEMO_TAG=r2.0.0rc0
ARG PYTRITON_VERSION=0.5.5
ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
ARG MLM_TAG=fbb375d4b5e88ce52f5f7125053068caff47f93f
ARG NEMO_TAG=1ff3a061da9751e4d645c8de66c0dfd27bd5d119
ARG PYTRITON_VERSION=0.5.5
ARG PROTOBUF_VERSION=4.24.4
ARG ALIGNER_COMMIT=main

@@ -37,7 +43,7 @@ RUN pip uninstall -y apex && \
git fetch origin $APEX_TAG && \
git checkout FETCH_HEAD; \
fi && \
pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
pip install -e . -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"

# place any util pkgs here
RUN pip install --upgrade-strategy only-if-needed nvidia-pytriton==$PYTRITON_VERSION
@@ -77,4 +83,19 @@ RUN git clone https://github.com/NVIDIA/NeMo-Aligner.git && \
fi && \
pip install --no-deps -e .

WORKDIR /workspace
# Git LFS
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
apt-get install git-lfs && \
git lfs install

# TRTLLM-0.9
RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git && \
cd TensorRT-LLM && \
git checkout v0.9.0 && \
git apply ../NeMo-Aligner/trtllm.patch && \
. docker/common/install_tensorrt.sh && \
python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt

RUN cd TensorRT-LLM && \
pip install ./build/tensorrt_llm*.whl
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3/compat/lib.real/
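
A quick sanity check that can follow the wheel install above (a sketch, not part of this PR's Dockerfile):

```bash
# Confirm the patched TensorRT-LLM 0.9.0 wheel imports and report its version.
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```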