Skip to content

Commit

Permalink
Fix (gptq): pin_memory only available with CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert committed Aug 27, 2024
1 parent 059fdc1 commit 1f6432b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def __init__(
self.H = torch.zeros((self.groups, self.columns, self.columns),
device='cpu',
dtype=torch.float32,
pin_memory=True)
pin_memory=torch.cuda.is_available())
self.B = torch.zeros((self.groups, self.columns, self.columns),
device='cpu',
dtype=torch.float32,
pin_memory=True)
pin_memory=torch.cuda.is_available())
self.nsamples = 0

assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
Expand Down

0 comments on commit 1f6432b

Please sign in to comment.