Skip to content

Commit

Permalink
Fix (llm): change device to parameter device
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed May 31, 2024
1 parent da84c88 commit 3c4a674
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
5 changes: 2 additions & 3 deletions src/brevitas_examples/llm/llm_quant/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ def create_validation_dataloader(data, seqlen, device):

@torch.no_grad()
def model_eval(model, valenc, seqlen):

nsamples = len(valenc)

dev = next(iter(model.parameters())).device
with torch.no_grad():
nlls = []
for inps in valenc:
lm_logits = model(**inps)['logits']
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = inps['input_ids'][:, 1:].to(model.device)
shift_labels = inps['input_ids'][:, 1:].to(dev)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
neg_log_likelihood = loss.float() * seqlen
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def main():
nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
val_data = get_wikitext2(
nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
val_data = create_validation_dataloader(val_data, args.seqlen, model.device)
device = next(iter(model.parameters())).device
val_data = create_validation_dataloader(val_data, args.seqlen, device)
print("Data loaded.")

# Apply LN affine merging before inserting MHA layers
Expand Down

0 comments on commit 3c4a674

Please sign in to comment.