
remove gpxq changes
Giuseppe5 committed May 31, 2024
1 parent 6f7d55f commit 724842c
Showing 3 changed files with 0 additions and 23 deletions.
5 changes: 0 additions & 5 deletions src/brevitas/graph/gpfq.py
@@ -263,9 +263,6 @@ def single_layer_update(self):
             # No permutation, permutation tensor is an ordered index
             perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
-
-        self.reactivate_quantization()
-
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
@@ -361,8 +358,6 @@ def single_layer_update(self):
             perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
 
-        self.reactivate_quantization()
-
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
5 changes: 0 additions & 5 deletions src/brevitas/graph/gptq.py
@@ -86,11 +86,9 @@ def catch_stopfwd(self, *args, **kwargs):
         # If we want to return the output of the network, we need to disable all hooks
         for name, gpxq_class in self.gpxq_layers.items():
             gpxq_class.disable_pre_forward_hook = True
-
         out = self.orig_forward(*args, **kwargs)
         for name, gpxq_class in self.gpxq_layers.items():
             gpxq_class.disable_pre_forward_hook = False
-
         return out
 
     def initialize_module_optimizer(
@@ -136,7 +134,6 @@ def __init__(
             device='cpu',
             dtype=torch.float32)
         self.nsamples = 0
-        self.done = False
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
 
@@ -260,8 +257,6 @@ def single_layer_update(self, percdamp=.01):
         finally:
             del self.H
 
-        self.reactivate_quantization()
-
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
             count = i2 - i1
13 changes: 0 additions & 13 deletions src/brevitas/graph/gpxq.py
@@ -164,10 +164,6 @@ def __enter__(self):
         return self
 
     def __exit__(self, type, value, traceback):
-        for name, layer in self.gpxq_layers.items():
-            if not layer.done:
-                layer.reactivate_quantization()
-
         if isinstance(self.model, (GraphModule, TorchGraphModule)):
             self.model.__class__.forward = self.orig_forward
         else:
@@ -223,10 +219,6 @@ def __init__(
         self.disable_pre_forward_hook = False
         # Some layers require knowledge from quant inputs to compute quant weights
         self.quant_metadata = None
-        self.disable_quant_inference = DisableEnableQuantization()
-        self.return_quant_tensor_state = disable_return_quant_tensor(self.layer)
-        self.disable_quant_inference.disable_param_quantization(self.layer, False)
-        self.done = False
 
     def process_input(self, inp):
         # Input is a tuple, so we take first element
@@ -263,11 +255,6 @@ def update_batch(self):
     def single_layer_update(self):
         pass
 
-    def reactivate_quantization(self):
-        self.done = True
-        self.disable_quant_inference.enable_param_quantization(self.layer, False)
-        restore_return_quant_tensor(self.layer, self.return_quant_tensor_state)
-
     def get_quant_weights(self, i, i1, permutation_list):
         # We need to recompute quant weights at runtime since our float weights are being updated
         # Add offset in case of blockwise computation
