Commit 1c32492: GPxQ for SD speed-up
Giuseppe5 committed May 2, 2024 (1 parent: 5c4968b)

Showing 5 changed files with 36 additions and 30 deletions.
src/brevitas/graph/gpfq.py (5 additions, 0 deletions)

@@ -263,6 +263,9 @@ def single_layer_update(self):
# No permutation, permutation tensor is a ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)
+
+ self.reactivate_quantization()
+
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(

@@ -358,6 +361,8 @@ def single_layer_update(self):
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

+ self.reactivate_quantization()
+
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
src/brevitas/graph/gptq.py (8 additions, 0 deletions)

@@ -9,6 +9,9 @@
from packaging import version
import torch

+ from brevitas.graph.calibrate import disable_return_quant_tensor
+ from brevitas.graph.calibrate import restore_return_quant_tensor
+
try:
from torch.linalg import LinAlgError
except:

@@ -86,9 +89,11 @@ def catch_stopfwd(self, *args, **kwargs):
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
+
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
+
return out

def initialize_module_optimizer(

@@ -134,6 +139,7 @@ def __init__(
device='cpu',
dtype=torch.float32)
self.nsamples = 0
+ self.done = False

assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"

@@ -257,6 +263,8 @@ def single_layer_update(self, percdamp=.01):
finally:
del self.H

+ self.reactivate_quantization()
+
for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1
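For orientation, the hunks above slot into the usual gptq_mode calibration loop from the brevitas PTQ examples. The sketch below is adapted from those examples and is illustrative only; argument names and the exact gptq_mode signature should be checked against the installed brevitas version.

```python
import torch
from brevitas.graph.gptq import gptq_mode

def apply_gptq(model, calib_loader, device):
    # Rough sketch of the standard driving loop: each layer sees the whole
    # calibration set once, then gptq.update() runs single_layer_update(),
    # which (after this commit) also calls reactivate_quantization().
    model.eval()
    with torch.no_grad():
        with gptq_mode(model) as gptq:
            gptq_model = gptq.model
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    gptq_model(images.to(device))
                gptq.update()
    return model
```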
src/brevitas/graph/gpxq.py (13 additions, 0 deletions)

@@ -164,6 +164,10 @@ def __enter__(self):
return self

def __exit__(self, type, value, traceback):
+ for name, layer in self.gpxq_layers.items():
+ if not layer.done:
+ layer.reactivate_quantization()
+
if isinstance(self.model, (GraphModule, TorchGraphModule)):
self.model.__class__.forward = self.orig_forward
else:

@@ -219,6 +223,10 @@ def __init__(
self.disable_pre_forward_hook = False
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_metadata = None
+ self.disable_quant_inference = DisableEnableQuantization()
+ self.return_quant_tensor_state = disable_return_quant_tensor(self.layer)
+ self.disable_quant_inference.disable_param_quantization(self.layer, False)
+ self.done = False

def process_input(self, inp):
# Input is a tuple, so we take first element

@@ -255,6 +263,11 @@ def update_batch(self):
def single_layer_update(self):
pass

+ def reactivate_quantization(self):
+ self.done = True
+ self.disable_quant_inference.enable_param_quantization(self.layer, False)
+ restore_return_quant_tensor(self.layer, self.return_quant_tensor_state)
+
def get_quant_weights(self, i, i1, permutation_list):
# We need to recompute quant weights at runtime since our float weights are being updated
# Add offset in case of blockwise computation
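Taken together, these three hunks give every GPxQ layer wrapper a small lifecycle: parameter quantization is switched off when the wrapper is built, switched back on exactly once in reactivate_quantization(), and __exit__ re-enables any layer whose update never ran. A deliberately simplified, hypothetical sketch of that lifecycle (none of these toy classes exist in brevitas):

```python
class _ToyLayerState:
    """Hypothetical stand-in for a GPxQ layer wrapper."""

    def __init__(self, name):
        self.name = name
        self.done = False
        self.quant_enabled = False  # mirrors disable_param_quantization(layer, False)

    def reactivate_quantization(self):
        self.done = True
        self.quant_enabled = True   # mirrors enable_param_quantization(layer, False)


class _ToyGpxqMode:
    """Hypothetical context manager mirroring the __exit__ safety net above."""

    def __init__(self, layer_names):
        self.gpxq_layers = {name: _ToyLayerState(name) for name in layer_names}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Layers whose single_layer_update never completed are still
        # re-enabled here, so the model never leaves the context with
        # quantization permanently turned off.
        for name, layer in self.gpxq_layers.items():
            if not layer.done:
                layer.reactivate_quantization()


with _ToyGpxqMode(["conv1", "fc"]) as mode:
    mode.gpxq_layers["conv1"].reactivate_quantization()  # update ran for conv1
# on exit, "fc" is reactivated by the fallback in __exit__
```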
src/brevitas_examples/common/generative/quantize.py (3 additions, 0 deletions)

@@ -4,6 +4,7 @@
"""
import re

+ import torch
from torch import nn

from brevitas import nn as qnn

@@ -190,6 +191,8 @@ def quantize_model(
'narrow_range': False,
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)
+ if dtype == torch.float16:
+ weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})

# Set the group_size is we're doing groupwise quantization
if weight_quant_granularity == 'per_group':
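The float16 branch above clamps the computed weight scale with scaling_min_val so it cannot collapse in half precision. A rough illustration of the failure mode it guards against; the numbers are arbitrary, chosen only to trigger fp16 underflow:

```python
import torch

w = torch.tensor(0.02)

# A scale this small is fine in float32 ...
scale_fp32 = torch.tensor(1e-8, dtype=torch.float32)
print(w / scale_fp32)                     # finite (about 2e6)

# ... but flushes to zero in float16, so dividing by it blows up.
scale_fp16 = scale_fp32.to(torch.float16)
print(scale_fp16)                         # 0.0
print(w.to(torch.float16) / scale_fp16)   # inf

# Clamping the scale from below, which is what scaling_min_val=1e-4
# effectively does, keeps the division well behaved.
safe_scale = torch.clamp(scale_fp16, min=1e-4)
print(w.to(torch.float16) / safe_scale)   # roughly 200
```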
src/brevitas_examples/stable_diffusion/main.py (7 additions, 30 deletions)

@@ -62,14 +62,6 @@
'close up photo of a rabbit, forest in spring, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot'
}

- clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
-
-
- def calculate_clip_score(images, prompts):
- images_int = (np.array(images)).astype("uint8")
- clip_score = clip_score_fn(torch.from_numpy(images_int).permute(2, 0, 1), prompts).detach()
- return round(float(clip_score), 4)
-

def run_test_inference(
pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''):

@@ -174,7 +166,11 @@ def main(args):
for m in pipe.unet.modules():
if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
m.in_features = m.module.in_features
- prompts = VALIDATION_PROMPTS
+ prompts = dict()
+ for i, (k, v) in enumerate(CALIBRATION_PROMPTS.items()):
+ if i == args.calibration_prompt:
+ break
+ prompts[k] = v
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

@@ -310,25 +306,6 @@ def input_bit_width(module):
fid.update(quant_images_values, real=False)
print(f"FID: {float(fid.compute())}")

- clip_score_float = 0.
- for prompt, images in float_images.items():
- clip_score_prompt = 0.
- for image in images:
- clip_score_prompt += calculate_clip_score(image, prompt)
- clip_score_float += clip_score_prompt / len(images)
- clip_score_float /= len(float_images)
-
- clip_score_quant = 0.
- for prompt, images in quant_images.items():
- clip_score_prompt = 0.
- for image in images:
- clip_score_prompt += calculate_clip_score(image, prompt)
- clip_score_quant += clip_score_prompt / len(images)
- clip_score_quant /= len(quant_images)
-
- print(f"CLIP float: {clip_score_float}")
- print(f"CLIP quant: {clip_score_quant}")
-
if args.export_target:
# Move to cpu and to float32 to enable CPU export
if not (dtype == torch.float16 and args.export_cuda_float16):

@@ -373,8 +350,8 @@ def input_bit_width(module):
'-b',
'--batch-size',
type=int,
- default=4,
- help='How many seeds to use for each image during validation. Default: 4')
+ default=2,
+ help='How many seeds to use for each image during validation. Default: 2')
parser.add_argument(
'--prompt',
type=int,
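Most of the speed-up in this file comes from validating on a small slice of CALIBRATION_PROMPTS instead of the full VALIDATION_PROMPTS set, dropping the CLIP-score pass, and halving the default number of seeds per image. The prompt-slicing loop in the second hunk could equivalently be written with itertools.islice; a hedged sketch, where select_prompts is a hypothetical helper and CALIBRATION_PROMPTS / args.calibration_prompt are the names used in main.py:

```python
from itertools import islice

def select_prompts(calibration_prompts, n):
    # Keep the first n entries, relying on dict insertion order (Python 3.7+).
    return dict(islice(calibration_prompts.items(), n))

# Equivalent to the explicit enumerate/break loop in the hunk above:
# prompts = select_prompts(CALIBRATION_PROMPTS, args.calibration_prompt)
```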
