Merge pull request #198 from gkumbhat/fix_dtype_prompt_tuning
Fix dtype prompt tuning
gkumbhat authored Sep 19, 2023
2 parents 86315fc + 8b73e37 · commit 6f09931
Showing 2 changed files with 56 additions and 10 deletions.
caikit_nlp/modules/text_generation/peft_prompt_tuning.py (49 additions, 10 deletions)
@@ -390,6 +390,7 @@ def train(
base_model.model.config.d_model = 1024

peft_model = get_peft_model(base_model.model, peft_config)

# Convert our Peft model (not just the underlying
# transformers model) to the right underlying type.
device = cls._get_device(device)
@@ -406,6 +407,7 @@ def train(
    tokenizer=base_model.tokenizer,
    accumulate_steps=accumulate_steps,
    silence_progress_bars=silence_progress_bars,
    torch_dtype=torch_dtype,
)

# Get config of the base model
@@ -963,6 +965,7 @@ def _execute_train_loop(
    tokenizer: Union[AutoTokenizer, None] = None,
    accumulate_steps: int = 1,
    silence_progress_bars: bool = True,
    torch_dtype: "torch.dtype" = torch.float32,
) -> None:
"""Execute the core training logic for training the prompt vectors on the frozen model.
Note that this is done by reference.
@@ -991,6 +994,8 @@
        Number of steps to use for gradient accumulation. Default: 1.
    silence_progress_bars: bool
        Silences TQDM progress bars. Default: True
    torch_dtype: torch.dtype
        Dtype to be used for training. Default: torch.float32
Returns:
    training_metadata: Dict
@@ -1003,8 +1008,36 @@
    num_training_steps=(len(train_dataloader) * num_epochs),
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

if torch_dtype == torch.float16:
    mixed_precision = "fp16"
elif (
    torch.cuda.is_available()
    and torch.cuda.is_bf16_supported()
    and torch_dtype == torch.bfloat16
):
    mixed_precision = "bf16"
else:
    mixed_precision = "no"

accelerator = Accelerator(
-   gradient_accumulation_steps=accumulate_steps, device_placement=True
+   gradient_accumulation_steps=accumulate_steps,
+   device_placement=True,
+   mixed_precision=mixed_precision,
)

# Disable cache for training
model.config.use_cache = False

# Below would send all the data and model to
# configured device and convert them to required dtypes
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model,
    optimizer,
    train_dataloader,
    lr_scheduler,
)

training_loss_tracker = []
@@ -1014,18 +1047,24 @@
total_loss = 0
tqdm_loader = tqdm(train_dataloader, disable=silence_progress_bars)
for batch in tqdm_loader:

    tqdm_loader.set_description("Epoch: {}".format(epoch))

    # TODO Can this dict comprehension always replace "batch.to(device)" for us?
    batch = {k: v.to(device) for k, v in batch.items()}
-   with accelerator.accumulate(model):
-       outputs = model(**batch)
-       loss = outputs.loss
-       total_loss += loss.detach().float()
-       accelerator.backward(loss)
-       optimizer.step()
-       lr_scheduler.step()
-       optimizer.zero_grad()
+   try:
+       with accelerator.accumulate(model):
+           outputs = model(**batch)
+           loss = outputs.loss
+           total_loss += loss.detach().float()
+           accelerator.backward(loss)
+           optimizer.step()
+           lr_scheduler.step()
+           optimizer.zero_grad()
+   except torch.cuda.OutOfMemoryError:
+       error(
+           "<NLP07175292E>",
+           MemoryError("Not enough memory available for training!"),
+       )

log.info("<NLP46114010I>", {"loss": float(loss), "epoch": epoch})
# Below is added to be propagated and stored as training_metadata
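Taken together, the hunks above follow a standard Accelerate mixed-precision recipe: map the requested dtype to a mixed-precision mode, build the Accelerator, prepare the training objects, and guard each step against CUDA out-of-memory failures. Below is a minimal, self-contained sketch of that recipe, assuming only torch and accelerate are installed; build_accelerator and train_step are illustrative names, not caikit_nlp API.

# Hedged sketch of the pattern this commit applies; not the library's code.
import torch
from accelerate import Accelerator

def build_accelerator(torch_dtype, accumulate_steps=1):
    # Map the requested dtype onto an Accelerate mixed-precision mode.
    # bf16 is only requested when the hardware actually supports it.
    if torch_dtype == torch.float16:
        mixed_precision = "fp16"
    elif (
        torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
        and torch_dtype == torch.bfloat16
    ):
        mixed_precision = "bf16"
    else:
        mixed_precision = "no"
    return Accelerator(
        gradient_accumulation_steps=accumulate_steps,
        device_placement=True,
        mixed_precision=mixed_precision,
    )

def train_step(accelerator, model, optimizer, lr_scheduler, batch):
    # accumulate() defers the optimizer step until the configured number
    # of gradient-accumulation micro-batches has been processed.
    with accelerator.accumulate(model):
        loss = model(**batch).loss
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    return loss

# Typical wiring (objects come from the caller):
#   accelerator = build_accelerator(torch.bfloat16, accumulate_steps=4)
#   model, optimizer, loader, sched = accelerator.prepare(
#       model, optimizer, loader, sched)
#   for batch in loader:
#       train_step(accelerator, model, optimizer, sched, batch)

Since accelerator.prepare with device_placement=True already moves the model and each batch to the configured device, the explicit batch-to-device move retained in the hunk above should be redundant but harmless, which is what the surviving TODO comment hints at.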
examples/run_peft_tuning.py (7 additions, 0 deletions)
@@ -235,6 +235,12 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> None:
    default=1,
    type=int,
)
subparser.add_argument(
    "--torch_dtype",
    help="Torch dtype to be used for training",
    default="float32",
    choices=["float16", "bfloat16", "float32"],
)


def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
@@ -407,6 +413,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
    verbalizer=dataset_info.verbalizer,
    silence_progress_bars=not args.verbose,
    accumulate_steps=args.accumulate_steps,
    torch_dtype=args.torch_dtype,
)
model.save(args.output_dir, save_base_model=not args.prompt_only)
print_colored("[Training Complete]")
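One detail worth noting: the example script forwards args.torch_dtype as a plain string ("float16", "bfloat16", or "float32"), while the training loop compares against torch.dtype objects, so train() presumably resolves the string to a dtype first. A hedged sketch of such a resolver follows; get_torch_dtype is an illustrative name, not confirmed caikit_nlp API.

import torch

def get_torch_dtype(dtype) -> "torch.dtype":
    # Accept either an actual torch.dtype or its string name.
    if isinstance(dtype, torch.dtype):
        return dtype
    resolved = getattr(torch, str(dtype), None)
    if not isinstance(resolved, torch.dtype):
        raise ValueError("Unsupported torch dtype: {}".format(dtype))
    return resolved

assert get_torch_dtype("bfloat16") is torch.bfloat16
assert get_torch_dtype(torch.float16) is torch.float16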
