From da84c88d36855d9eac1e70b34abdca1a6c117a52 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Tue, 30 Apr 2024 13:59:58 +0100
Subject: [PATCH 1/2] Fix (llm): fix device issue for eval when not using default device

---
 src/brevitas_examples/llm/llm_quant/eval.py | 6 +++---
 src/brevitas_examples/llm/main.py           | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py
index a9b2b4b64..b60d544c2 100644
--- a/src/brevitas_examples/llm/llm_quant/eval.py
+++ b/src/brevitas_examples/llm/llm_quant/eval.py
@@ -21,11 +21,11 @@
 from tqdm import tqdm
 
 
-def create_validation_dataloader(data, seqlen):
+def create_validation_dataloader(data, seqlen, device):
     nsamples = data['input_ids'].numel() // seqlen
     val_dataloader = []
     for i in tqdm(range(nsamples)):
-        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
+        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].to(device)
         attention_mask = torch.ones_like(batch)
         val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
     return val_dataloader
@@ -41,7 +41,7 @@ def model_eval(model, valenc, seqlen):
         for inps in valenc:
             lm_logits = model(**inps)['logits']
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            shift_labels = inps['input_ids'][:, 1:].cuda()
+            shift_labels = inps['input_ids'][:, 1:].to(model.device)
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
             neg_log_likelihood = loss.float() * seqlen
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e86a8d4ba..b033f0e26 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -278,7 +278,7 @@ def main():
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
     val_data = get_wikitext2(
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
-    val_data = create_validation_dataloader(val_data, args.seqlen)
+    val_data = create_validation_dataloader(val_data, args.seqlen, model.device)
     print("Data loaded.")
 
     # Apply LN affine merging before inserting MHA layers

From 3c4a674ffca736b806c348ab8b16ad227b39a426 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Fri, 31 May 2024 11:16:50 +0100
Subject: [PATCH 2/2] Fix (llm): change device to parameter device

---
 src/brevitas_examples/llm/llm_quant/eval.py | 5 ++---
 src/brevitas_examples/llm/main.py           | 3 ++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py
index b60d544c2..271a5b36e 100644
--- a/src/brevitas_examples/llm/llm_quant/eval.py
+++ b/src/brevitas_examples/llm/llm_quant/eval.py
@@ -33,15 +33,14 @@ def create_validation_dataloader(data, seqlen, device):
 
 
 @torch.no_grad()
 def model_eval(model, valenc, seqlen):
-    nsamples = len(valenc)
-
+    dev = next(iter(model.parameters())).device
     with torch.no_grad():
         nlls = []
         for inps in valenc:
             lm_logits = model(**inps)['logits']
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            shift_labels = inps['input_ids'][:, 1:].to(model.device)
+            shift_labels = inps['input_ids'][:, 1:].to(dev)
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
             neg_log_likelihood = loss.float() * seqlen
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index b033f0e26..5237c31c7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -278,7 +278,8 @@ def main():
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
     val_data = get_wikitext2(
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
-    val_data = create_validation_dataloader(val_data, args.seqlen, model.device)
+    device = next(iter(model.parameters())).device
+    val_data = create_validation_dataloader(val_data, args.seqlen, device)
    print("Data loaded.")
 
     # Apply LN affine merging before inserting MHA layers
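
Note on the pattern (a sketch accompanying the patches, not part of them): both commits
converge on resolving the evaluation device from the model's own parameters rather than
hardcoding .cuda() or relying on model.device, which is a Hugging Face PreTrainedModel
property that a plain torch nn.Module does not expose. A minimal standalone sketch of
that pattern follows; the helper name get_model_device is hypothetical, not from the
Brevitas codebase:

    import torch
    import torch.nn as nn

    def get_model_device(model: nn.Module) -> torch.device:
        # Hypothetical helper: take the device of the first parameter.
        # Raises StopIteration on a parameter-less module, so callers
        # may want to fall back to a default device in that case.
        return next(iter(model.parameters())).device

    model = nn.Linear(8, 8)                  # stays on CPU unless moved explicitly
    device = get_model_device(model)         # -> device(type='cpu')
    batch = torch.randint(0, 100, (1, 16)).to(device)  # mirrors the .to(device) calls in the diff

One caveat: for a model sharded across several devices (e.g. loaded with accelerate's
device_map), the first parameter's device is only a heuristic, since inputs must land
on the device holding the embedding layer.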