Commit a44b75d
Move check to prevent running peft with VP to a more correct phase of setup (NVIDIA#8216)

Signed-off-by: Valerie Sarge <vsarge@nvidia.com>
vysarge authored Jan 23, 2024
1 parent a39f526 commit a44b75d
Showing 2 changed files with 3 additions and 3 deletions.
@@ -89,9 +89,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         if hasattr(self.cfg.data.test_ds, "metric"):
             self.test_metric_label_key = self.cfg.data.test_ds.metric.get('label_key', 'labels')
 
-        if self.use_peft and self.cfg.get('virtual_pipeline_model_parallel_size', None):
-            raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')
-
         # Set the profile start and end steps in the unit of global batach
         if hasattr(self, '_nsys_profile_enabled'):
             self._nsys_profile_start_step = self.cfg.nsys_profile.get('start_step', 0)
nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py (3 additions, 0 deletions)
@@ -173,6 +173,9 @@ def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]):
             peft_cfgs: One or more PEFTConfig objects that specify the PEFT method configuration
         """
 
+        if self.cfg.get('virtual_pipeline_model_parallel_size', None):
+            raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')
+
         if not isinstance(peft_cfgs, List):
             peft_cfgs = [peft_cfgs]
 
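
For orientation, here is a minimal standalone sketch of the behavior this diff enforces: the guard now sits at the top of add_adapter, so any model whose config sets virtual_pipeline_model_parallel_size is rejected before adapters are attached, regardless of which training entry point constructed the model. Apart from the config key and error message, which come from the diff above, the class and names below are illustrative assumptions, not NeMo's actual NLPAdapterModelMixin implementation.

# Minimal standalone sketch of the relocated guard. Only the config key
# 'virtual_pipeline_model_parallel_size' and the error message are taken
# from the diff above; everything else is an illustrative assumption.

class AdapterMixinSketch:
    def __init__(self, cfg: dict):
        self.cfg = cfg

    def add_adapter(self, peft_cfgs):
        # The moved check: fail fast when virtual pipeline model parallelism
        # is configured, before any adapter weights are created.
        if self.cfg.get('virtual_pipeline_model_parallel_size', None):
            raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')
        if not isinstance(peft_cfgs, list):
            peft_cfgs = [peft_cfgs]
        # ... the real mixin would attach the adapter modules here ...
        return peft_cfgs


# A config that enables virtual pipeline parallelism is now rejected at
# add_adapter() time, independent of which model class built the object.
model = AdapterMixinSketch({'virtual_pipeline_model_parallel_size': 2})
try:
    model.add_adapter(object())
except ValueError as err:
    print(err)  # Virtual pipeline model parallel is not supported when using PEFT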
