diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d2d27e95a..4ec29dba8 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -95,6 +95,8 @@ jobs: - kd-llama3 - sft-llama3 - rm-llama3 + - dpo-mixtral-ep + - dpo-mixtral-peft-tp-sp with: RUNNER: self-hosted-azure # Fairly aggresive timeout that all functional tests should try to adhere to diff --git a/Dockerfile b/Dockerfile index 4d22c6a21..2ad368dda 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,8 +13,8 @@ ARG MAX_JOBS=8 # Git refs for dependencies ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG PYTRITON_VERSION=0.5.10 -ARG NEMO_TAG=19668e5320a2e2af0199b6d5e0b841993be3a634 # On: main -ARG MLM_TAG=25059d3bbf68be0751800f3644731df12a88f3f3 # On: main +ARG NEMO_TAG=06eae2895c0fea09f8dd7c34feff0163e55c419a # On: main +ARG MLM_TAG=844119f5c856a3037ec7c7f6d6ef7b3518ceee6b # On: main ARG ALIGNER_COMMIT=main ARG TRTLLM_VERSION=v0.13.0 ARG PROTOBUF_VERSION=4.24.4 @@ -123,19 +123,19 @@ RUN cd /opt/NeMo-Aligner && \ RUN cd TensorRT-LLM && patch -p1 < ../NeMo-Aligner/setup/trtllm.patch -# TODO(terryk): This layer should be deleted ASAP after NeMo is bumped to include all of these PRs +# NOTE: Comment this layer out if it is not needed +# NOTE: This section exists to allow cherry-picking PRs in cases where +# we do not wish to simply update to the top-of-tree. Sometimes PRs +# cannot be cherry-picked cleanly if rebased a few times to top-of-tree +# so this logic also requires you to select a SHA (can be dangling) from +# the PR. RUN <<"EOF" bash -exu cd NeMo # Ensures we don't cherry-pick "future" origin/main commits git fetch -a -# 0c92fe17df4642ffc33d5d8c0c83fda729e3910c: [fix] Ensures disabling exp_manager with exp_manager=null does not error NeMo#10651 -# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652 -# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863 -# (superceded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654 +# d27dd28b4186f6ecd9f46f1c5679a5eef9bad14e: fix: export weight name mapping if model is nemo model#11497 for pr_and_commit in \ - "10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \ - "10652 60e677423667c029dd05875da72bf0719774f844" \ - "10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \ + "11497 d27dd28b4186f6ecd9f46f1c5679a5eef9bad14e" \ ; do pr=$(cut -f1 -d' ' <<<"$pr_and_commit") head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit") diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 192265244..3b589a3fe 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -6,6 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value # dpo specific args dpo: @@ -17,6 +18,7 @@ trainer: # how many GBS we loop over limit_val_batches: 1.0 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # do not change these diff --git a/examples/nlp/gpt/conf/gpt_kto.yaml b/examples/nlp/gpt/conf/gpt_kto.yaml index de264056a..f6cd60059 100644 --- a/examples/nlp/gpt/conf/gpt_kto.yaml +++ b/examples/nlp/gpt/conf/gpt_kto.yaml @@ -6,6 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value # kto specific args kto: @@ -17,6 +18,7 @@ trainer: # how many GBS we loop over limit_val_batches: 1.0 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # do not change these diff --git a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml index e0a5a1045..22b899e50 100644 --- a/examples/nlp/gpt/conf/gpt_ppo_actor.yaml +++ b/examples/nlp/gpt/conf/gpt_ppo_actor.yaml @@ -7,6 +7,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value ppo: # How many steps we train warmup the critic for (without training the policy) @@ -21,6 +22,7 @@ trainer: max_steps: -1 # max PPO steps (-1 to go through the whole train set) val_check_interval: 10 save_interval: ${.val_check_interval} + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # PPO args to generate the data for training diff --git a/examples/nlp/gpt/conf/gpt_ppo_critic.yaml b/examples/nlp/gpt/conf/gpt_ppo_critic.yaml index 75974767f..8e146eb8c 100644 --- a/examples/nlp/gpt/conf/gpt_ppo_critic.yaml +++ b/examples/nlp/gpt/conf/gpt_ppo_critic.yaml @@ -6,6 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value ppo: port: 5556 @@ -15,6 +16,7 @@ trainer: # used to set the learning rate scheduler max_steps: 10000 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # a PyTriton parameter to specify diff --git a/examples/nlp/gpt/conf/gpt_rs_actor.yaml b/examples/nlp/gpt/conf/gpt_rs_actor.yaml index b819ca287..6ff1a228a 100644 --- a/examples/nlp/gpt/conf/gpt_rs_actor.yaml +++ b/examples/nlp/gpt/conf/gpt_rs_actor.yaml @@ -7,12 +7,14 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value rs: max_epochs: 1 max_steps: -1 # max rs steps (-1 to go through the whole train set) val_check_interval: 10 save_interval: ${.val_check_interval} + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # pick up from the model @@ -177,4 +179,4 @@ model: # define fields from the base model's config that should be ignored when merging with this config. overwrite_base_config: data: - data_prefix: True \ No newline at end of file + data_prefix: True diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index 9946b094a..ec56d8698 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -5,6 +5,7 @@ trainer: devices: 1 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value sft: max_epochs: 1 @@ -15,6 +16,7 @@ trainer: limit_train_batches: 1.0 limit_val_batches: 1.0 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # can be used to register any custom metrics that require token-by-token generation diff --git a/examples/nlp/gpt/conf/gpt_spin.yaml b/examples/nlp/gpt/conf/gpt_spin.yaml index 4027dbf8e..96772d975 100644 --- a/examples/nlp/gpt/conf/gpt_spin.yaml +++ b/examples/nlp/gpt/conf/gpt_spin.yaml @@ -6,6 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16-mixed + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value # spin specific args spin: @@ -18,6 +19,7 @@ trainer: # how many GBS we loop over limit_val_batches: 1.0 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # do not change these diff --git a/examples/nlp/gpt/conf/training_rm.yaml b/examples/nlp/gpt/conf/training_rm.yaml index afe927423..77a2ba09c 100644 --- a/examples/nlp/gpt/conf/training_rm.yaml +++ b/examples/nlp/gpt/conf/training_rm.yaml @@ -6,6 +6,7 @@ trainer: devices: 8 accelerator: gpu precision: bf16 + gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value # rm specific args rm: @@ -20,6 +21,7 @@ trainer: # set to float for a percentage # of the validation dataset limit_val_batches: 1.0 + # TODO: delete once Megatron Core optimizer becomes default gradient_clip_val: 1.0 # do not change these diff --git a/nemo_aligner/algorithms/critic_server_trainer.py b/nemo_aligner/algorithms/critic_server_trainer.py index d3a7e0d8c..ff91214ac 100644 --- a/nemo_aligner/algorithms/critic_server_trainer.py +++ b/nemo_aligner/algorithms/critic_server_trainer.py @@ -322,7 +322,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() if grad_norm is not None: diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 626b7b58e..75b773106 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -220,7 +220,7 @@ def train_single_step(self, global_batch): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() trainer_metrics = {} diff --git a/nemo_aligner/algorithms/ppo.py b/nemo_aligner/algorithms/ppo.py index 323c18224..3851cc4b1 100644 --- a/nemo_aligner/algorithms/ppo.py +++ b/nemo_aligner/algorithms/ppo.py @@ -440,7 +440,7 @@ def run_training(self, dataloader_iter): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() if grad_norm is not None: diff --git a/nemo_aligner/algorithms/rs.py b/nemo_aligner/algorithms/rs.py index 493b743d4..11bb7b141 100644 --- a/nemo_aligner/algorithms/rs.py +++ b/nemo_aligner/algorithms/rs.py @@ -294,7 +294,7 @@ def run_training(self, dataloader_iter): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() if grad_norm is not None: diff --git a/nemo_aligner/algorithms/spin.py b/nemo_aligner/algorithms/spin.py index 717daaa53..f40611957 100644 --- a/nemo_aligner/algorithms/spin.py +++ b/nemo_aligner/algorithms/spin.py @@ -195,7 +195,7 @@ def train_single_step(self, global_batch): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() trainer_metrics = {} diff --git a/nemo_aligner/algorithms/supervised.py b/nemo_aligner/algorithms/supervised.py index 3f2f67c61..ed3ce707d 100644 --- a/nemo_aligner/algorithms/supervised.py +++ b/nemo_aligner/algorithms/supervised.py @@ -150,7 +150,7 @@ def train_single_step(self, batch): grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm lr = self.optimizer.param_groups[0]["lr"] - self.optimizer.step() + self.optimizer.step(closure=None) self.scheduler.step() trainer_metrics = {} diff --git a/nemo_aligner/utils/train_utils.py b/nemo_aligner/utils/train_utils.py index da176b785..1883632cf 100644 --- a/nemo_aligner/utils/train_utils.py +++ b/nemo_aligner/utils/train_utils.py @@ -101,31 +101,52 @@ def prepare_for_training_step(ptl_model, zero_grad=True): param.data_ptr() +# TODO: Delete this once API introduced in NeMo (https://github.com/NVIDIA/NeMo/pull/10803) +# TODO: Update PR to move this logic into staticmethod in nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py def grad_reductions(ptl_model): # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced if ptl_model.cfg.get("tensor_model_parallel_size", 1) > 1 and ptl_model.cfg.get("sequence_parallel", False): - ptl_model.allreduce_sequence_parallel_gradients() - - if ptl_model.with_distributed_adam: - # synchronize asynchronous grad reductions - # note: not necessary, but reduces performance degradation - # from multiple simultaneous NCCL calls - ptl_model._optimizer._finish_bucket_grad_sync() + # Mcore DistOpt handles this, so we don't have to + if not ptl_model.use_mcore_dist_optim: + ptl_model.megatron_timer_start("allreduce_sequence_parallel_gradients", log_level=1) + ptl_model.allreduce_sequence_parallel_gradients() + ptl_model.megatron_timer_stop("allreduce_sequence_parallel_gradients") + + ptl_model.megatron_timer_start("gradient_allreduce", log_level=1) + if ptl_model.use_fsdp: + # Reduce the gradients omitted from FSDP-sharding + ptl_model.allreduce_fsdp_sharding_omitted_gradients() + elif ptl_model.with_distributed_adam: + if not ptl_model.use_mcore_dist_optim: + # synchronize asynchronous grad reductions + # note: not necessary, but reduces performance degradation + # from multiple simultaneous NCCL calls + ptl_model._optimizer._finish_bucket_grad_sync() + # else: Mcore distributed optim calls finalize_model_grads to finish grad sync elif ptl_model.megatron_amp_O2: # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) - if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 or ptl_model.cfg.get("sequence_parallel", False): + if ( + ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 + or ptl_model.cfg.get("sequence_parallel", False) + or not ptl_model.cfg.get("async_grad_allreduce", True) + ): # main grads are stored in the MainParamsOptimizer wrapper ptl_model._optimizer.allreduce_main_grads() else: # async grad allreduce is not currently implemented for O1/autocasting mixed precision training # so we all-reduce gradients after the pipeline ptl_model.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + ptl_model.megatron_timer_stop("gradient_allreduce") - if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 and ptl_model.cfg.get( - "share_embeddings_and_output_weights", True + if ( + not ptl_model.use_mcore_dist_optim + and ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 + and ptl_model.cfg.get("share_embeddings_and_output_weights", True) ): + ptl_model.megatron_timer_start("allreduce_first_last_embeddings", log_level=1) # when using pipeline parallelism the first and last stage must keep embeddings in sync ptl_model.allreduce_first_last_embeddings() + ptl_model.megatron_timer_stop("allreduce_first_last_embeddings") def prepare_for_validation_step(ptl_model): @@ -155,7 +176,11 @@ def set_eval(ptl_model): ptl_model.eval() +# TODO: adapt the version in /opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_base_model.py def clip_gradients(ptl_model, clip_val): + """PTL hook to configure gradients. + We use gradient clipping implementation from megatron-lm. + """ if clip_val is None: return @@ -163,6 +188,14 @@ def clip_gradients(ptl_model, clip_val): if clip_val <= 0: return + if ptl_model.with_megatron_fused_adam or ptl_model.use_mcore_dist_optim: + # Gradient clipping is done in optimizer step + return + + if ptl_model.grad_clip_pl_default: + # use the default behavior + return super().configure_gradient_clipping(*args, **kwargs) + if ptl_model.with_distributed_adam: grad_norm = clip_grad_norm_distributed_optimizer(ptl_model._optimizer, clip_val) else: @@ -171,6 +204,5 @@ def clip_gradients(ptl_model, clip_val): parameters = ptl_model._optimizer.get_parameters_with_grad() else: parameters = ptl_model.get_parameters_with_grad() - grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val) - + grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val, use_fsdp=ptl_model.use_fsdp,) return grad_norm diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index bc073dcde..8db4dda8c 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -1,14 +1,26 @@ #!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cd $SCRIPT_DIR set -eoux pipefail export NCCL_ALGO=Tree -export NVTE_APPLY_QK_LAYER_SCALING=1 +export NVTE_APPLY_QK_LAYER_SCALING=${NVTE_APPLY_QK_LAYER_SCALING:-0} -KL=${KL:-0.1} -GBS=${GBS:-4} PRETRAINED_CHECKPOINT_NEMO_FILE=${PRETRAINED_CHECKPOINT_NEMO_FILE} @@ -23,7 +35,6 @@ mkdir -p $RESULTS_DIR GPFS=$(git rev-parse --show-toplevel) -# START HETEROGENEUS JOB 3 CONF_DIR="${GPFS}/examples/nlp/gpt/conf/" CONF_NAME="gpt_dpo" @@ -33,38 +44,40 @@ dpo() { export CUDA_VISIBLE_DEVICES=0,1 export PYTHONPATH="${GPFS}:${PYTHONPATH:-}" export HYDRA_FULL_ERROR=1 -torchrun --nproc-per-node 2 ${GPFS}/examples/nlp/gpt/train_gpt_dpo.py \ +torchrun --nproc_per_node=2 ${GPFS}/examples/nlp/gpt/train_gpt_dpo.py \ --config-path=${CONF_DIR} \ --config-name=${CONF_NAME} \ trainer.num_nodes=1 \ trainer.devices=2 \ - ++model.data.data_impl=jsonl \ - ++model.data.seq_length=128 \ - ++model.global_batch_size=${GBS} \ + pretrained_checkpoint.restore_from_path=${PRETRAINED_CHECKPOINT_NEMO_FILE} \ + exp_manager.create_checkpoint_callback=False \ + exp_manager.explicit_log_dir=${RESULTS_DIR} \ + ++model.tensor_model_parallel_size=1 \ + ++model.pipeline_model_parallel_size=1 \ + ++model.global_batch_size=4 \ ++model.micro_batch_size=1 \ ++model.mcore_gpt=true \ ++model.megatron_amp_O2=true \ - ++model.dpo.ref_policy_kl_penalty=${KL} \ + ++model.dpo.ref_policy_kl_penalty=0.1 \ ++model.dpo.log_prob_forward_micro_batch_size=1 \ ++model.dpo.average_log_probs=false \ ++model.dpo.sft_loss_weight=0.1 \ ++model.dpo.preference_loss_weight=1.0 \ - pretrained_checkpoint.restore_from_path=${PRETRAINED_CHECKPOINT_NEMO_FILE} \ - "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ - exp_manager.create_checkpoint_callback=False \ + ++model.activations_checkpoint_granularity=full \ + ++model.activations_checkpoint_method=uniform \ + ++model.activations_checkpoint_num_layers=1 \ + ++model.dist_ckpt_load_strictness=log_all \ + ++model.data.data_impl=jsonl \ + ++model.data.seq_length=128 \ model.data.num_workers=2 \ - ++model.tensor_model_parallel_size=1 \ - ++model.pipeline_model_parallel_size=1 \ + "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ trainer.dpo.max_steps=3 \ trainer.dpo.val_check_interval=3 \ trainer.dpo.limit_val_batches=8 \ trainer.dpo.save_interval=0 \ - exp_manager.explicit_log_dir=${RESULTS_DIR} \ - ++model.activations_checkpoint_granularity=full \ - ++model.activations_checkpoint_method=uniform \ - ++model.activations_checkpoint_num_layers=1 \ - ++model.dist_ckpt_load_strictness=log_all + "$@" } log_file=$(mktemp /tmp/dpo-log-XXXXXX) -dpo | tee $log_file \ No newline at end of file +dpo "$@" | tee $log_file +echo "[Finished] $0" diff --git a/tests/functional/test_cases/dpo-llama3 b/tests/functional/test_cases/dpo-llama3 index 8e40e94c8..f841ab8b0 100755 --- a/tests/functional/test_cases/dpo-llama3 +++ b/tests/functional/test_cases/dpo-llama3 @@ -1,5 +1,6 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,4 +20,6 @@ cd $SCRIPT_DIR set -eoux pipefail PRETRAINED_CHECKPOINT_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ -bash ../dpo.sh +bash ../dpo.sh \ + ++model.optim.name=mcore_distributed_optim \ + 2>&1 | tee $(basename $0).log diff --git a/tests/functional/test_cases/dpo-mixtral-ep b/tests/functional/test_cases/dpo-mixtral-ep new file mode 100755 index 000000000..79f6ffd1d --- /dev/null +++ b/tests/functional/test_cases/dpo-mixtral-ep @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +PRETRAINED_CHECKPOINT_NEMO_FILE=$ALIGNER_CI_DIR/checkpoints/tiny-mixtral-nlayers2-hidden128-ffn448-nhead4-qgroup2.nemo \ +bash ../dpo.sh \ + ++model.optim.name=mcore_distributed_optim \ + ++model.expert_model_parallel_size=2 \ + 2>&1 | tee $(basename $0).log + diff --git a/tests/functional/test_cases/dpo-mixtral-peft-tp-sp b/tests/functional/test_cases/dpo-mixtral-peft-tp-sp new file mode 100755 index 000000000..49e209b5f --- /dev/null +++ b/tests/functional/test_cases/dpo-mixtral-peft-tp-sp @@ -0,0 +1,41 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +args=( + model.peft.peft_scheme=lora + # MoE needs mcore dist opt + ++model.optim.name=mcore_distributed_optim + ++model.tensor_model_parallel_size=2 + ++model.expert_model_parallel_size=1 + # SP needed for TP>1 + ++model.sequence_parallel=True + # Seqlen % TP_SIZE when SP=True + model.data.pad_length_to_multiple_of=2 + ++model.tp_comm_overlap_disable_qkv=True + ++model.moe_token_dispatcher_type=alltoall + # TODO: Activation checkpointing is not currently functional with peft + ~model.activations_checkpoint_granularity + ~model.activations_checkpoint_method + ~model.activations_checkpoint_num_layers +) + +PRETRAINED_CHECKPOINT_NEMO_FILE=$ALIGNER_CI_DIR/checkpoints/tiny-mixtral-nlayers2-hidden128-ffn448-nhead4-qgroup2.nemo \ +bash ../dpo.sh "${args[@]}" 2>&1 | tee $(basename $0).log