From 059fdc1fa5a8342f3f29fcc69030be003c895a6f Mon Sep 17 00:00:00 2001
From: Ian Colbert
Date: Mon, 26 Aug 2024 16:54:24 -0700
Subject: [PATCH 1/2] Feat (gptq): optimizing CPU to GPU memory transfer

---
 src/brevitas/graph/gptq.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 31d31433b..973919e03 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -132,7 +132,12 @@ def __init__(
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
-                             dtype=torch.float32)
+                             dtype=torch.float32,
+                             pin_memory=True)
+        self.B = torch.zeros((self.groups, self.columns, self.columns),
+                             device='cpu',
+                             dtype=torch.float32,
+                             pin_memory=True)
         self.nsamples = 0
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
@@ -184,7 +189,9 @@ def update_batch(self, module, input, current_layer):
             self.H *= self.nsamples / (self.nsamples + batch_size)
             self.nsamples += batch_size
             inp_processed = math.sqrt(2 / self.nsamples) * inp_processed.to(torch.float32)
-            self.H += (inp_processed.bmm(inp_processed.transpose(2, 1))).to(self.H.device)
+            # optimizing CPU to GPU transfer using in-place copy to pinned memory
+            self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
+            self.H += self.B
         # If we are executing GPTQ with group of parallel layers, we keep track of how many forward
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
@@ -255,7 +262,7 @@ def single_layer_update(self, percdamp=.01):
                     f'Increasing the number of samples might fix this issue')
                 return
             finally:
-                del self.H
+                del self.H, self.B
 
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)

From 1f6432b2991405e4377b9615bfb71ffb18ae59a9 Mon Sep 17 00:00:00 2001
From: Ian Colbert
Date: Mon, 26 Aug 2024 18:11:43 -0700
Subject: [PATCH 2/2] Fix (gptq): pin_memory only available with CUDA

---
 src/brevitas/graph/gptq.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 973919e03..a1380da4e 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -133,11 +133,11 @@ def __init__(
         self.H = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
                              dtype=torch.float32,
-                             pin_memory=True)
+                             pin_memory=torch.cuda.is_available())
         self.B = torch.zeros((self.groups, self.columns, self.columns),
                              device='cpu',
                              dtype=torch.float32,
-                             pin_memory=True)
+                             pin_memory=torch.cuda.is_available())
         self.nsamples = 0
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
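Note: the pattern above stages the GPU-side bmm() result in a page-locked
("pinned") CPU buffer and accumulates from there, rather than calling
.to(self.H.device) on a freshly allocated pageable tensor every batch. A
minimal, self-contained sketch of that pattern follows; the shapes and the
groups/columns/n_samples values are illustrative assumptions, not Brevitas
code.

    import torch

    groups, columns, n_samples = 4, 512, 8

    # pin_memory is only supported when a CUDA runtime is present,
    # hence the guard introduced in the second patch.
    pin = torch.cuda.is_available()

    H = torch.zeros((groups, columns, columns), device='cpu',
                    dtype=torch.float32, pin_memory=pin)
    B = torch.zeros((groups, columns, columns), device='cpu',
                    dtype=torch.float32, pin_memory=pin)

    device = 'cuda' if pin else 'cpu'
    inp_processed = torch.randn((groups, columns, n_samples), device=device)

    # copy_() lands the GPU result directly in the pinned staging buffer B,
    # so the device-to-host transfer can DMA from page-locked memory instead
    # of allocating a new pageable tensor on every batch.
    B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
    H += B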
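A rough way to check the benefit (an assumption-laden sketch, not from the
patch; requires a CUDA device) is to time a pageable .to('cpu') against a
copy_() into a reused pinned buffer:

    import time

    import torch

    assert torch.cuda.is_available()
    x = torch.randn((4, 512, 512), device='cuda')
    pinned = torch.empty(x.shape, dtype=x.dtype, device='cpu', pin_memory=True)

    def bench(fn, iters=50):
        # Synchronize around the loop so we time the transfers themselves,
        # not just kernel launches.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    pageable = bench(lambda: x.to('cpu'))      # allocates pageable memory per call
    staged = bench(lambda: pinned.copy_(x))    # reuses the page-locked buffer
    print(f'pageable: {pageable * 1e3:.3f} ms | pinned copy_: {staged * 1e3:.3f} ms')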